https://github.com/GPflow/GPflow
Revision 61088fd9496ee8f574dbadeb4f067719a696f1c7 authored by st-- on 19 June 2018, 13:19:35 UTC, committed by Artem Artemev on 19 June 2018, 13:19:35 UTC
Notebooks are slow to execute. Additional utility functions help detect whether notebooks are being run and control the number of iterations (optimization steps, standard loops, et cetera), thereby minimizing the time spent running notebook integration tests.
1 parent bb08f22
Raw File
Tip revision: 61088fd9496ee8f574dbadeb4f067719a696f1c7 authored by st-- on 19 June 2018, 13:19:35 UTC
Speed up notebooks (#789)
Tip revision: 61088fd
optimizer.py
# Copyright 2017 Artem Artemev @awav
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=no-self-use
# pylint: disable=too-few-public-methods

import abc


class Optimizer:
    """
    Base interface for GPflow optimizers.

    Subclasses are expected to override `make_optimize_tensor` and `minimize`.
    NOTE(review): these methods are decorated with `abc.abstractmethod`, but the
    class does not use `abc.ABCMeta` as its metaclass, so the decorators are not
    enforced when a subclass is instantiated.
    """

    @abc.abstractmethod
    def make_optimize_tensor(self, model, session=None, var_list=None, **kwargs):
        """
        Make optimization tensor.
        The `make_optimize_tensor` method builds optimization tensor and initializes
        all necessary variables created by optimizer.

            :param model: GPflow model.
            :param session: Tensorflow session.
            :param var_list: List of variables for training.
            :param kwargs: Dictionary of extra parameters necessary for building
                optimizer tensor.
            :return: Tensorflow optimization tensor or operation.
        """
        pass

    @abc.abstractmethod
    def minimize(self, model, session=None, var_list=None, feed_dict=None,
                 maxiter=1000, initialize=True, anchor=True, **kwargs):
        """
        Run the optimizer on `model`.

        Must be implemented by subclasses; the base implementation raises
        `NotImplementedError`.

            :param model: GPflow model.
            :param session: Tensorflow session.
            :param var_list: List of extra variables for training.
            :param feed_dict: Optional extra feed values (see `_gen_feed_dict`).
            :param maxiter: Presumably the maximum number of optimization
                iterations; the exact meaning is subclass-defined.
            :param initialize: Subclass-defined flag, presumably whether to
                initialize variables before optimizing. Confirm in subclasses.
            :param anchor: Subclass-defined flag, presumably whether to write
                optimized values back into the model. Confirm in subclasses.
            :param kwargs: Extra optimizer-specific parameters.
        """
        raise NotImplementedError()

    @staticmethod
    def _gen_var_list(model, var_list):
        """
        Collect the variables to optimize.

            :param model: Object exposing a `trainable_tensors` iterable.
            :param var_list: Optional iterable of extra variables; `None` means none.
            :return: List holding the deduplicated union of `model.trainable_tensors`
                and `var_list`, sorted by each variable's `name` attribute.
        """
        var_list = var_list or []
        all_vars = set(model.trainable_tensors).union(var_list)
        # Set iteration order is arbitrary; sorting by name makes the result deterministic.
        return sorted(all_vars, key=lambda x: x.name)

    @staticmethod
    def _gen_feed_dict(model, feed_dict):
        """
        Merge a user-supplied feed dictionary with the model's own feeds.

        The caller's `feed_dict` is copied and never modified in place. When both
        define the same key, the value from `model.feeds` wins.

            :param model: Object exposing `feeds` (a mapping or `None`).
            :param feed_dict: Optional user feed dictionary; `None` means empty.
            :return: A new merged dictionary, or `None` when the merge is empty.
        """
        # Copy so the caller's dictionary is not mutated as a side effect.
        merged = dict(feed_dict) if feed_dict else {}
        model_feeds = model.feeds
        if model_feeds is not None:
            merged.update(model_feeds)
        return merged or None
back to top