Revision 458c6a6d1725c6d7b21773e762ceca01faa0bb23 authored by william cowley on 15 October 2021, 17:10:45 UTC, committed by william cowley on 15 October 2021, 17:10:45 UTC
1 parent 5e92b2e
Raw File
training_mixins.py
# Copyright 2020 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module provides mixin classes to be used in conjunction with inheriting
from gpflow.models.BayesianModel (or its subclass gpflow.models.GPModel).

They provide a unified interface to obtain closures that return the training
loss, to be passed as the first argument to the minimize() method of the
optimizers defined in TensorFlow and GPflow.

All TrainingLossMixin classes assume that self._training_loss() (which is
provided by the BayesianModel base class) will be available. Note that new
models only need to implement the maximum_log_likelihood_objective method,
which is declared abstract in BayesianModel.

There are different mixins depending on whether the model already contains the
training data (InternalDataTrainingLossMixin), or requires it to be passed in
to the objective function (ExternalDataTrainingLossMixin).
"""

import abc
from typing import Callable, Iterator, Optional, Tuple, TypeVar, Union

import numpy as np
import tensorflow as tf
from tensorflow.python.data.ops.iterator_ops import OwnedIterator as DatasetOwnedIterator

# Type aliases used to annotate the data accepted by the training-loss methods in this module.
# NOTE(review): a single-member Union collapses to the member itself, so both aliases are
# currently plain tf.Tensor; presumably the Union form is kept as an extension point - confirm.
InputData = Union[tf.Tensor]
OutputData = Union[tf.Tensor]
# An (inputs, outputs) pair.
RegressionData = Tuple[InputData, OutputData]
# Constrained type variable over the data forms above; it annotates the `data` argument of
# ExternalDataTrainingLossMixin.training_loss and training_loss_closure.
Data = TypeVar("Data", RegressionData, InputData, OutputData)


class InternalDataTrainingLossMixin:
    """
    Training-loss helpers for models that hold their own training data.

    The mixin contributes two methods:

      - :meth:`training_loss`, a uniform way of evaluating the training loss;
      - :meth:`training_loss_closure`, which builds the zero-argument callable expected by the
        minimize methods of :class:`gpflow.optimizers.Scipy` and of subclasses of
        `tf.optimizers.Optimizer`.

    The class this is mixed into must supply ``self._training_loss()``.

    Models that do **not** own their data should use :class:`ExternalDataTrainingLossMixin`
    instead.
    """

    def training_loss(self) -> tf.Tensor:
        """
        Evaluate the training loss of this model by delegating to ``self._training_loss()``.
        """
        loss = self._training_loss()
        return loss

    def training_loss_closure(self, *, compile=True) -> Callable[[], tf.Tensor]:
        """
        Build a zero-argument callable that returns the training loss, ready to be handed to the
        minimize methods of :class:`gpflow.optimizers.Scipy` and subclasses of
        `tf.optimizers.Optimizer`.

        :param compile: If `True` (default), the returned callable is :meth:`training_loss`
            wrapped in tf.function(); otherwise the plain bound method is returned.
        """
        loss_fn = self.training_loss
        return tf.function(loss_fn) if compile else loss_fn


class ExternalDataTrainingLossMixin:
    """
    Training-loss helpers for models that do **not** hold their own training data, so the data
    has to be supplied whenever the loss is evaluated.

    The mixin contributes two methods:

      - :meth:`training_loss`, a uniform way of evaluating the training loss on given data;
      - :meth:`training_loss_closure`, which builds the zero-argument callable expected by the
        minimize methods of :class:`gpflow.optimizers.Scipy` and of subclasses of
        `tf.optimizers.Optimizer`.

    The class this is mixed into must supply ``self._training_loss(data)``.

    Models that **do** own their data should use :class:`InternalDataTrainingLossMixin` instead.
    """

    def training_loss(self, data: Data) -> tf.Tensor:
        """
        Evaluate the training loss of this model on `data` by delegating to
        ``self._training_loss(data)``.

        :param data: the data on which the model objective is computed.
        """
        loss = self._training_loss(data)
        return loss

    def training_loss_closure(
        self,
        data: Union[Data, DatasetOwnedIterator],
        *,
        compile=True,
    ) -> Callable[[], tf.Tensor]:
        """
        Build a zero-argument callable that computes the training loss. By default the
        computation is wrapped in tf.function(); pass `compile=False` to disable this.

        :param data: the data the closure uses to compute the model objective. Either the full
            dataset, or an iterator such as `iter(dataset.batch(batch_size))` where dataset is
            an instance of tf.data.Dataset. With an iterator, every call of the closure draws
            the next batch from it.
        :param compile: if True, wrap the training loss in tf.function()
        """
        loss_fn = self.training_loss

        if not isinstance(data, DatasetOwnedIterator):
            # Fixed data: the closure always evaluates the loss on the same `data`, and when
            # compiling, the whole closure is wrapped.
            def closure():
                return loss_fn(data)

            return tf.function(closure) if compile else closure

        if compile:
            # Only the loss itself is wrapped, with a signature taken from the iterator's
            # element spec; the closure below stays a plain Python function so that next()
            # runs on every call.
            loss_fn = tf.function(loss_fn, input_signature=[data.element_spec])

        def closure():
            return loss_fn(next(data))

        return closure
back to top