Revision f3154bfca90196e06d42c6b27d44d35e27e58563 authored by Artem Artemev on 03 December 2018, 00:36:31 UTC, committed by GitHub on 03 December 2018, 00:36:31 UTC
2 parent s 897112b + 7f0cac1
Raw File
actions.py
# Copyright 2018 Artem Artemev @awav
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import collections
import inspect
import itertools
from timeit import default_timer as timer
from typing import Any, Callable, Dict, Optional, Sequence, Union

import numpy as np
import tensorflow as tf

from . import get_default_session
from .models import Model


class Watcher:
    """
    Watcher is a sort of time manager.
    
    It holds information when timer is started, and when it is stopped.
    If elapsed time is asked and timer is not stopped, then different between
    current time and started is returned. Otherwise, different between
    start and stop time.
    """

    def __init__(self):
        self._start = None
        self._stop = None

    def start(self):
        """
        Set start time [µs].
        """
        self._start = timer()
    
    def stop(self):
        """
        Set stop time [µs].
        """
        self._stop = timer()
    
    @property
    def elapsed(self):
        """
        Elapsed time [µs] between start and stop timestamps. If stop is empty then
        returned time is difference between start and current timestamp.
        """
        if self._stop is None:
            return timer() - self._start
        return self._stop - self._start
    

class ActionContext:
    """
    Action context contains sharable data between sequence of action runs.

        :param owner: Action to which context originally belongs to.
        :param session: Tensorflow session which will be used when necessary by
            all dependent actions ran by owner.
    """

    def __init__(self, owner: 'Action', session: Optional[tf.Session] = None):
        if session is None:
            session = tf.get_default_session()
            session = get_default_session() if session is None else session
        self.session = session
        self.owner = owner
    
    @property
    def iteration(self) -> int:
        """
        Current iteration number.
        In fact, this is proxy method for Loop's owner interaction.

            :return: Iteration number or None, in case when owner doesn't
                have any references to iteration.
        """
        return _get_attr(self.owner, iteration=None)
    
    @property
    def time_spent(self) -> float:
        """
        Elapsed time since owner started running dependent actions.

            :return: Spent time [µs] by owner on running up to the call.
        """
        return self.owner.watcher.elapsed


class Action(metaclass=abc.ABCMeta):
    """
    Action base abstraction for wrapping functionality inside function-like container.
    
    Actions are functions with extra utilities. Each action has watcher
    which gauge how much time action spent on execution. Action may contain
    information about iteration, but usually 
    """
    @abc.abstractmethod
    def run(self, context: ActionContext) -> None:
        """
        Run action's execution. Must be implemented by decendants.

            :param context: Required action context.
        """
        pass
    
    @property
    def watcher(self) -> Watcher:
        """
        Gives an access to action's watcher.

            :return: Action's watcher instance.
        """
        if not hasattr(self, "_watcher"):
            self._watcher = Watcher()
        return self._watcher
    
    def __call__(self, context: Optional[ActionContext] = None) -> None:
        """
        Action call method.
        The `action()` activates watcher, then calls `run` method and deactivates
        watcher even when `run` method raised an exception.

            :param context: Optional action context value. When argument is none
                the standard ActionContext is created and callee action assigns
                itself as a context owner.
        """
        context = ActionContext(self) if context is None else context
        try:
            self.watcher.start()
            self.run(context)
        finally:
            self.watcher.stop()


# ======================
# Action implementations
# ======================

class Loop(Action):
    """
    Action loop abstraction.
    Default loop is an infinite with starting point at zero and update step one. 
    It evaluates body action as many times as it is specified by start, stop and step.
    Loop provides two control signals - 'Break' and 'Continue'.
    These signals must be raised as a standard exception. Depending on which signal
    was raised the loop either stops execution or continues from beginning.

        :param action: Action or list of actions, which will be run within loop.
        :param stop: Maximum loop iteration.
        :param start: Starting point for loop iterations.
        :param step: Integer interval between iterations.
    """

    class Continue(Exception):
        """Continue signal returns execution to the beginning of the loop."""
        pass

    class Break(Exception):
        """Break signal stops loop immediately."""
        pass

    def __init__(self,
            action: Action,
            stop: Optional[int] = None,
            start: int = 0,
            step: int = 1) -> None:
        self._action = _try_convert_action(action)
        self.start = start
        self.stop = stop
        self.step = step
    
    @property
    def iteration(self) -> int:
        """Active iteration number."""
        return _get_attr(self, _iteration=None)
    
    def with_iteration(self, iteration: int):
        """
        Set iteration number.
        
            :param iteration: Iteration interger number.
            :return: Loop itself.
        """
        self._iteration = iteration
        return self

    def with_action(self, action: Action) -> 'Loop':
        """
        Set loop body action.
            
            :param action: Loop body action.
            :return: Loop itself.
        """
        self._action = _try_convert_action(action)
        return self

    def with_settings(self,
            stop: Optional[int] = None,
            start: int = 0,
            step: int = 1) -> 'Loop':
        """
        Set start, stop and step loop configuration.

            :param stop: Looop stop iteration integer. If None then loop
                becomes infinite.
            :param start: Loop iteration start integer.
            :param step: Loop iteration interval integer.
            :return: Loop itself.
        """
        self.start = start
        self.stop = stop
        self.step = step
        return self

    def run(self, context: ActionContext):
        """
        Run performs loop iterations.

            :param context: Action context.
        """
        iterator = itertools.count(start=self.start, step=self.step)
        for i in iterator:
            self.with_iteration(i)
            if self.stop is not None and i >= self.stop:
                break
            try:
                self._action(context)
            except Loop.Continue:
                continue
            except Loop.Break:
                break
    

class Group(Action):
    """
    Group is an ordered list of actions.
    Group runs list of actions step by step, executing them separately
    in single for loop.

        :param actions: List of actions.
    """
    def __init__(self, *actions: Sequence[Action]) -> None:
        self._actions = actions
    
    def run(self, context: ActionContext):
        """
        Run executes list of cached actions in the for loop.
        """
        for action in self._actions:
            action(context)


class Condition(Action):
    """
    Condition actions is a branching mechanism.
    Depending on boolean output of callable conditional function, either
    true or false passed action is executed.

        :param condition_fn: Boolean function with single input, current context.
        :param action_true: True action, which is executed when output of
            conditional function returns true.
        :param action_false: False action, which is executed when output of
            conditional function returns false.
    """
    def __init__(self,
            condition_fn: Callable[[ActionContext], bool],
            action_true: Action,
            action_false: Optional[Action] = None) -> None:
        self.condition_fn = condition_fn
        self.action_false = action_false
        self.action_true = action_true

    def run(self, context: ActionContext):
        cond = self.condition_fn(context)
        if cond:
            return self.action_true(context)
        elif self.action_false is None:
            return None
        return self.action_false(context)


class Optimization(Action):
    """
    Optimization is an action wrapper for GPflow optimizers.
    It contains an optimizer, a model and TensorFlow minimization tensor
    generated by the optimizer and model.
    This action may be used as single optimization step.
    """

    def with_optimizer(self, optimizer) -> 'Optimization':
        """
        Replace optimizer.

            :param optimizer: GPflow optimizer.
            :return: Optimization instance self reference.
        """
        self._optimizer = optimizer
        return self


    def with_model(self, model: Model) -> 'Optimization':
        """
        Replace model.

            :param model: GPflow model.
            :return: Optimization instance self reference.
        """
        self._model = model
        return self
    
    def with_optimizer_tensor(self, tensor: Union[tf.Tensor, tf.Operation]) -> 'Optimization':
        """
        Replace optimizer tensor.

            :param model: Tensorflow tensor.
            :return: Optimization instance self reference.
        """
        self._optimizer_tensor = tensor
        return self
    
    def with_run_kwargs(self, **kwargs: Dict[str, Any]) -> 'Optimization':
        """
        Replace Tensorflow session run kwargs.
        Check Tensorflow session run [documentation](https://www.tensorflow.org/api_docs/python/tf/Session).

            :param kwargs: Dictionary of tensors as keys and numpy arrays or
                primitive python types as values.
            :return: Optimization instance self reference.
        """
        self._run_kwargs = kwargs
        return self

    @property
    def model(self) -> Model:
        """The `model` is an attribute for getting optimization's model."""
        return _get_attr(self, _model=None)

    @property
    def optimizer_tensor(self) -> Union[tf.Tensor, tf.Operation]:
        """
        The `optimizer_tensor` is an attribute for getting optimization's
        optimizer tensor.
        """
        return _get_attr(self, _optimizer_tensor=None)
    
    @property
    def run_kwargs(self):
        """The `run_kwargs` is an attribute for getting session run's kwargs."""
        return _get_attr(self, _run_kwargs=None)
    
    def run(self, context: ActionContext) -> None:
        context.session.run(self.optimizer_tensor, **self.run_kwargs)


def _get_attr(obj, **attr):
    assert len(attr) == 1
    return getattr(obj, list(attr.keys())[0], list(attr.values())[0])


def _is_action(a):
    return isinstance(a, Action) or (callable(a) and len(inspect.signature(a).parameters) == 1)


def _try_convert_action(action):
    if _is_action(action):
        return action
    if not isinstance(action, list):
        raise ValueError('Expected either action type of list of actions.')
    for a in action:
        if _is_action(a):
            continue
        raise ValueError('List consists of non homogeneous action types.')
    return Group(*action)
back to top