# Raw file: test_methods.py
# Copyright 2016 the GPflow authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import List

import numpy as np
import pytest
from numpy.testing import assert_array_equal, assert_array_less

import gpflow
from gpflow.base import AnyNDArray

# ------------------------------------------
# Data classes: storing constants
# ------------------------------------------


# Sample data drawn once, at import time, from a RandomState seeded with 0.
# The draw order (X, Y, Z, Xs) determines which values each array receives,
# so these statements must not be reordered.
_DATUM_RNG = np.random.RandomState(0)
_DATUM_X: AnyNDArray = _DATUM_RNG.randn(100, 2)
_DATUM_Y: AnyNDArray = _DATUM_RNG.randn(100, 1)
_DATUM_Z: AnyNDArray = _DATUM_RNG.randn(10, 2)
_DATUM_XS: AnyNDArray = _DATUM_RNG.randn(10, 2)


@dataclass(frozen=True)
class Datum:
    """Immutable bundle of the constants used by the tests in this module.

    The defaults are supplied through ``default_factory`` because, from
    Python 3.11 on, ``dataclasses`` raises ``ValueError`` for unhashable
    default values, and ``numpy.ndarray`` is unhashable.  Every factory
    returns the same module-level object, so all ``Datum()`` instances built
    without arguments expose identical arrays and share one ``RandomState``
    (drawing from ``rng`` therefore advances the state seen by all of them).
    """

    # Shared random state, already advanced by the four draws above.
    rng: np.random.RandomState = field(default_factory=lambda: _DATUM_RNG)
    # (100, 2) and (100, 1) standard-normal samples.
    X: AnyNDArray = field(default_factory=lambda: _DATUM_X)
    Y: AnyNDArray = field(default_factory=lambda: _DATUM_Y)
    # Two independent (10, 2) standard-normal samples.
    Z: AnyNDArray = field(default_factory=lambda: _DATUM_Z)
    Xs: AnyNDArray = field(default_factory=lambda: _DATUM_XS)
    # Unannotated, hence plain class attributes rather than dataclass fields:
    # one likelihood and one kernel object shared by every instance.
    lik = gpflow.likelihoods.Gaussian()
    kernel = gpflow.kernels.Matern32()


# Single Datum instance that supplies the data, kernel and likelihood for
# every model constructed below.
default_datum = Datum()

# The (X, Y) pair handed to each model that takes data at construction time.
_training_data = (default_datum.X, default_datum.Y)

# Models constructed with the training data.  The order matters: pytest
# derives the parametrized test ids from each model's position in the list.
_gp_models: List[gpflow.models.GPModel] = [
    gpflow.models.VGP(_training_data, default_datum.kernel, default_datum.lik),
    gpflow.models.GPMC(_training_data, default_datum.kernel, default_datum.lik),
    gpflow.models.SGPMC(
        _training_data,
        default_datum.kernel,
        default_datum.lik,
        inducing_variable=default_datum.Z,
    ),
    gpflow.models.SGPR(
        _training_data,
        default_datum.kernel,
        inducing_variable=default_datum.Z,
    ),
    gpflow.models.GPR(_training_data, default_datum.kernel),
    gpflow.models.GPRFITC(
        _training_data,
        default_datum.kernel,
        inducing_variable=default_datum.Z,
    ),
]

# SVGP is constructed without any data: only kernel, likelihood and
# inducing variable are passed.
_state_less_gp_models: List[gpflow.models.GPModel] = [
    gpflow.models.SVGP(
        default_datum.kernel,
        default_datum.lik,
        inducing_variable=default_datum.Z,
    ),
]


@pytest.mark.parametrize("model", _state_less_gp_models + _gp_models)
def test_methods_predict_f(model: gpflow.models.GPModel) -> None:
    """Check the shapes and the lower bound of what ``predict_f`` returns.

    For the 10 points in ``default_datum.Xs`` the two returned values must
    have equal shapes, that shape must be ``(10, 1)``, and every entry of the
    second value must be strictly greater than ``-1e-6``.
    """
    f_mean, f_var = model.predict_f(default_datum.Xs)

    expected_shape = (10, 1)
    assert_array_equal(f_mean.shape, f_var.shape)
    assert_array_equal(f_mean.shape, expected_shape)

    # A small negative tolerance instead of 0 leaves room for round-off.
    lower_bound = np.full_like(f_var, -1e-6)
    assert_array_less(lower_bound, f_var)


@pytest.mark.parametrize("model", _state_less_gp_models + _gp_models)
def test_methods_predict_y(model: gpflow.models.GPModel) -> None:
    """Check the shapes and the lower bound of what ``predict_y`` returns.

    For the 10 points in ``default_datum.Xs`` the two returned values must
    have equal shapes, that shape must be ``(10, 1)``, and every entry of the
    second value must be strictly greater than ``-1e-6``.
    """
    y_mean, y_var = model.predict_y(default_datum.Xs)

    expected_shape = (10, 1)
    assert_array_equal(y_mean.shape, y_var.shape)
    assert_array_equal(y_mean.shape, expected_shape)

    # A small negative tolerance instead of 0 leaves room for round-off.
    lower_bound = np.full_like(y_var, -1e-6)
    assert_array_less(lower_bound, y_var)


@pytest.mark.parametrize("model", _state_less_gp_models + _gp_models)
def test_methods_predict_log_density(model: gpflow.models.GPModel) -> None:
    """Check that ``predict_log_density`` returns one value per test point.

    Ten targets of shape ``(10, 1)`` are drawn from the random state of a
    default-constructed ``Datum`` and paired with ``default_datum.Xs``; the
    result must have shape ``(10,)``.
    """
    targets = Datum().rng.randn(10, 1)
    log_density = model.predict_log_density((default_datum.Xs, targets))
    assert_array_equal(log_density.shape, (10,))
# (back to top)