Revision c2e94baae46c7b97cdd666b6d4a6fa3c6bdd8e9f authored by Artem Artemev on 05 November 2017, 23:34:06 UTC, committed by Artem Artemev on 05 November 2017, 23:34:06 UTC
1 parent 7c8e935
Raw File
decors.py
# Copyright 2017 Artem Artemev @awav
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import contextlib

import tensorflow as tf

from .core.base import GPflowError
from .core.base import Build
from .core.base import AutoBuildTag
from .core.node import Node
from .core.autoflow import AutoFlow
from .core.tensor_converter import TensorConverter

from .params import Parameterized


def name_scope(name=None):
    """Return a decorator that runs the wrapped callable inside ``tf.name_scope``.

    Args:
        name: Scope name to open. When ``None``, the wrapped callable's
            ``__name__`` is used (looked up each time the wrapper runs).

    Returns:
        A decorator producing a ``functools.wraps``-preserving wrapper.
    """
    def decorate(method):
        @functools.wraps(method)
        def runnable(*args, **kwargs):
            if name is None:
                scope = method.__name__
            else:
                scope = name
            with tf.name_scope(scope):
                return method(*args, **kwargs)
        return runnable
    return decorate


def params_as_tensors(method):
    """Decorate a method so it runs with its object's tensor mode switched on.

    The wrapper rejects objects that are not ``Parameterized`` with a
    ``GPflowError``. Otherwise it sets the tensor-mode flag to ``True`` for the
    duration of the call and restores the previous state afterwards, even if
    the method raises.
    """
    @functools.wraps(method)
    def tensor_mode_wrapper(obj, *args, **kwargs):
        if isinstance(obj, Parameterized):
            saved_mode = _params_as_tensors_enter(obj, True)
            try:
                return method(obj, *args, **kwargs)
            finally:
                _params_as_tensors_exit(obj, saved_mode)
        raise GPflowError(
            'Tensor mode works only for parmeterized object.')
    return tensor_mode_wrapper


def autobuild(switch=True):
    """Return a decorator that optionally tags a call with ``AutoBuildTag.IGNORE``.

    Args:
        switch: When falsy, the wrapper binds the local variable
            ``__execute_autobuild__`` to ``AutoBuildTag.IGNORE`` before
            calling the wrapped method. When truthy, the wrapper only
            forwards the call.

    Returns:
        A decorator whose wrapper calls ``method(*args, **kwargs)`` and
        returns its result.
    """
    def autobuild_wrapper(method):
        @functools.wraps(method)
        def runnable(*args, **kwargs):
            if not switch:
                # NOTE(review): this local is never read in this function.
                # It presumably acts as a marker that other auto-build code
                # finds by inspecting stack frames (e.g. ``f_locals``).
                # Confirm before renaming or removing it, because any such
                # lookup would depend on this exact variable name.
                __execute_autobuild__ = AutoBuildTag.IGNORE
            return method(*args, **kwargs)
        return runnable
    return autobuild_wrapper


@contextlib.contextmanager
def params_as_tensors_for(obj, convert=True):
    """Context manager that sets ``obj``'s tensor-mode flag for a ``with`` body.

    Args:
        obj: Object whose tensor-mode attribute is toggled.
        convert: Value stored in the tensor-mode attribute while inside the block.

    On exit the previous value is restored. If the previous value was absent
    or ``None``, the attribute is deleted instead. This happens even when the
    body raises.
    """
    saved_mode = _params_as_tensors_enter(obj, convert)
    try:
        yield
    finally:
        _params_as_tensors_exit(obj, saved_mode)


def autoflow(*af_args, **af_kwargs):
    """Decorator factory that turns a graph-building method into a cached runner.

    On the first call for a given object and method name, placeholders are
    created from ``af_args`` (each spec is unpacked into ``tf.placeholder``).
    The method is then called with those placeholders to build a result
    tensor, and both are stored in the object's AutoFlow storage. Every call
    evaluates the stored result through ``_session_run``, feeding the
    positional call arguments into the placeholders.

    Args:
        *af_args: Placeholder specifications, one per positional argument of
            the decorated method (after ``obj``).
        **af_kwargs: Forwarded to ``_setup_storage``.

    The decorated method also accepts keyword arguments:
        session: Passed to ``obj.enquire_session`` to pick the session.
        feed_dict, initialize, and any other ``session.run`` keyword arguments:
            Forwarded to ``_session_run``.

    Raises:
        GPflowError: If ``obj`` is not a ``Node``, or if
            ``obj.is_built_coherence(obj.graph)`` reports ``Build.NO``.
    """
    def autoflow_wrapper(method):
        @functools.wraps(method)
        def runnable(obj, *args, **kwargs):
            if not isinstance(obj, Node):
                raise GPflowError(
                    'AutoFlow works only with node-like objects.')
            if obj.is_built_coherence(obj.graph) is Build.NO:
                raise GPflowError('Not built with "{graph}".'.format(graph=obj.graph))
            name = method.__name__
            store = AutoFlow.get_autoflow(obj, name)
            session = kwargs.pop('session', None)
            session = obj.enquire_session(session=session)
            if not store:
                # Build into a scratch dict and publish to the cached store
                # only after the whole graph fragment was created. Writing
                # straight into ``store`` would leave it half-filled
                # ('arguments' without 'result') if building raised. The next
                # call would then skip the build and fail with KeyError.
                built = {}
                scope_name = _name_scope_name(obj, name)
                with session.graph.as_default(), tf.name_scope(scope_name):
                    _setup_storage(built, *af_args, **af_kwargs)
                    _build_method(method, obj, built)
                for key, value in built.items():
                    store[key] = value
            return _session_run(session, obj, store, *args, **kwargs)
        return runnable
    return autoflow_wrapper


def _params_as_tensors_enter(obj, convert=True):
    """Set ``obj``'s tensor-mode attribute to ``convert`` and return its old value.

    The attribute name is ``TensorConverter.__tensor_mode__``. If the
    attribute is missing, ``None`` is returned as the old value.
    """
    attr_name = TensorConverter.__tensor_mode__
    old_value = getattr(obj, attr_name, None)
    setattr(obj, attr_name, convert)
    return old_value


def _params_as_tensors_exit(obj, previous):
    """Undo ``_params_as_tensors_enter`` by restoring ``previous``.

    If ``previous`` is ``None``, the tensor-mode attribute is deleted from
    ``obj``. Otherwise it is set back to ``previous``.
    """
    attr_name = TensorConverter.__tensor_mode__
    if previous is None:
        delattr(obj, attr_name)
        return
    setattr(obj, attr_name, previous)


def _setup_storage(store, *args, **_kwargs):
    store['arguments'] = [tf.placeholder(*arg) for arg in args]


def _name_scope_name(obj, name):
    return '/'.join(['autoflow', obj.name, name])


def _session_run(session, obj, store, *args, **kwargs):
    feed_dict_key = 'feed_dict'
    if feed_dict_key not in kwargs:
        kwargs[feed_dict_key] = {}
    feed_dict = kwargs.get(feed_dict_key)
    feed_dict.update(dict(zip(store['arguments'], args)))
    if obj.feeds:
        feed_dict.update(obj.feeds)
    initialize = kwargs.pop('initialize', True)
    obj.initialize(session=session, force=initialize)
    return session.run(store['result'], **kwargs)


def _build_method(method, obj, store):
    store['result'] = method(obj, *store['arguments'])
back to top