Revision e24fd815cdfb8c654249da4576aeff6c2ce5a8ea authored by vdutor on 10 September 2020, 15:35:12 UTC, committed by vdutor on 10 September 2020, 15:35:12 UTC
1 parent 61b2e1c
Raw File
util.py
from typing import Callable, Union

import numpy as np
import tensorflow as tf

from ..config import default_float
from ..inducing_variables import InducingPoints, InducingVariables
from .model import BayesianModel
from .training_mixins import Data, ExternalDataTrainingLossMixin


def inducingpoint_wrapper(
    inducing_variable: Union[InducingVariables, tf.Tensor, np.ndarray]
) -> InducingVariables:
    """
    Normalise an inducing-variable argument into an InducingVariables object.

    :param inducing_variable: either an InducingVariables instance, which is
        handed back unchanged, or any other value (e.g. an array of inducing
        point positions), which is wrapped in InducingPoints.
    :return: an InducingVariables object.
    """
    if isinstance(inducing_variable, InducingVariables):
        return inducing_variable
    return InducingPoints(inducing_variable)


def _assert_equal_data(data1, data2):
    """
    Assert that two data structures hold equal values.

    If both arguments are single ``tf.Tensor`` objects, they are compared
    directly. Otherwise both are treated as sequences (e.g. an ``(X, Y)``
    tuple) and compared element by element.

    :raises: whatever ``tf.debugging.assert_equal`` raises on a mismatch,
        including when the two sequences have different lengths.
    """
    if isinstance(data1, tf.Tensor) and isinstance(data2, tf.Tensor):
        tf.debugging.assert_equal(data1, data2)
        return

    items1 = tuple(data1)
    items2 = tuple(data2)
    # zip() silently stops at the shorter sequence. Without this check, a
    # model holding (X, Y) would "match" data passed as just (X,).
    tf.debugging.assert_equal(
        len(items1), len(items2), message="data structures have different lengths"
    )
    for v1, v2 in zip(items1, items2):
        tf.debugging.assert_equal(v1, v2)


def training_loss_closure(
    model: BayesianModel, data: Data, **closure_kwargs
) -> Callable[[], tf.Tensor]:
    """
    Return a zero-argument closure that evaluates ``model``'s training loss.

    Models using ExternalDataTrainingLossMixin are given ``data`` directly.
    For any other model, ``data`` is only checked against ``model.data`` with
    ``_assert_equal_data``, and the model's own closure (over its stored
    data) is returned.

    :param model: the model to build the closure for.
    :param data: the training data.
    :param closure_kwargs: forwarded unchanged to
        ``model.training_loss_closure``.
    :return: the closure produced by ``model.training_loss_closure``.
    """
    if not isinstance(model, ExternalDataTrainingLossMixin):
        _assert_equal_data(model.data, data)
        return model.training_loss_closure(**closure_kwargs)
    return model.training_loss_closure(data, **closure_kwargs)


def training_loss(model: BayesianModel, data: Data) -> tf.Tensor:
    """
    Evaluate the training loss of ``model``.

    :param model: the model whose loss is computed.
    :param data: passed to ``model.training_loss`` for models using
        ExternalDataTrainingLossMixin. For any other model it is only checked
        against ``model.data`` with ``_assert_equal_data``, and the model's
        stored data is used.
    :return: the tensor returned by ``model.training_loss``.
    """
    uses_external_data = isinstance(model, ExternalDataTrainingLossMixin)
    if uses_external_data:
        return model.training_loss(data)
    _assert_equal_data(model.data, data)
    return model.training_loss()


def maximum_log_likelihood_objective(model: BayesianModel, data: Data) -> tf.Tensor:
    """
    Evaluate ``model``'s maximum-log-likelihood objective.

    :param model: the model whose objective is computed.
    :param data: passed through for models using
        ExternalDataTrainingLossMixin. Otherwise it is only checked against
        ``model.data`` with ``_assert_equal_data``, and the model's stored
        data is used.
    :return: the tensor returned by ``model.maximum_log_likelihood_objective``.
    """
    if not isinstance(model, ExternalDataTrainingLossMixin):
        _assert_equal_data(model.data, data)
        return model.maximum_log_likelihood_objective()
    return model.maximum_log_likelihood_objective(data)


def data_input_to_tensor(structure):
    """
    Convert every non-tensor leaf of a (possibly nested) structure to a
    TensorFlow tensor, keeping the structure itself intact.

    Leaves are handled as follows:

    - values that are already tensors (``tf.is_tensor``) are returned as-is;
    - NumPy arrays are converted with their original dtype preserved;
    - any other leaf (Python scalars, lists, ...) is converted to GPflow's
      default float type, ``default_float()``.

    :param structure: a single value or any nesting understood by ``tf.nest``.
    :return: the same nesting with every leaf converted to a tensor.
    """

    def convert_to_tensor(elem):
        if tf.is_tensor(elem):
            return elem
        if isinstance(elem, np.ndarray):
            # No dtype argument: keep the array's own dtype (e.g. integer
            # labels are not coerced to floats).
            return tf.convert_to_tensor(elem)
        return tf.convert_to_tensor(elem, dtype=default_float())

    return tf.nest.map_structure(convert_to_tensor, structure)
back to top