https://github.com/GPflow/GPflow
Revision 5a5504dd83fd282362c157016eb168603db3db90 authored by Artem Artemev on 31 January 2019, 09:57:51 UTC, committed by Artem Artemev on 31 January 2019, 09:57:51 UTC
1 parent d55e452
Raw File
Tip revision: 5a5504dd83fd282362c157016eb168603db3db90 authored by Artem Artemev on 31 January 2019, 09:57:51 UTC
Add callback to scipy optimizer
Tip revision: 5a5504d
optimize.py
from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union

import tensorflow as tf

# Short alias for TensorFlow's contrib eager-execution module.
tfe = tf.contrib.eager


# Data accepted by the optimisation helpers: a single tensor, or a pair of
# tensors (presumably inputs and targets — confirm with callers).
InputData = Union[Tuple[tf.Tensor, tf.Tensor], tf.Tensor]
# The variables that `loss_gradients` differentiates the loss with respect to.
Variables = List[tf.Variable]

# Callable producing the loss tensor; `loss_gradients` invokes it with no arguments.
LossCallback = Callable[..., tf.Tensor]
# Per-iteration hook called as (iteration, loss, gradients); its return value is ignored.
StepCallback = Callable[[int, tf.Tensor, List[tf.Tensor]], None]

# Optimizer types accepted by this module. A single-member `Union` collapses to
# that member, so this alias is exactly `tf.train.Optimizer`; widen it into a
# real `Union[...]` if further optimizer kinds become supported.
Optimizer = tf.train.Optimizer


def create_iterator(*args, batch_size=None, buffer_size=1000, shuffle=True, repeat=True):
    """
    Build a one-shot iterator over slices of the given tensors.

    Args:
        *args: Tensors (or array-likes) handed together, as a tuple, to
            `tf.data.Dataset.from_tensor_slices`; each dataset element is a
            tuple holding one slice from every argument.
        batch_size: Number of elements per batch. When None, no batching is done.
        buffer_size: Shuffle buffer size forwarded to `Dataset.shuffle`; only
            used when `shuffle` is true.
        shuffle: Whether to shuffle elements (applied before batching).
        repeat: Whether to repeat the dataset indefinitely (applied last).
    Return:
        The iterator returned by `make_one_shot_iterator()` on the resulting dataset.
    """
    dataset = tf.data.Dataset.from_tensor_slices(args)
    # Order matters: shuffle individual elements, then batch, then repeat.
    dataset = dataset.shuffle(buffer_size) if shuffle else dataset
    dataset = dataset if batch_size is None else dataset.batch(batch_size)
    dataset = dataset.repeat() if repeat else dataset
    return dataset.make_one_shot_iterator()


def loss_gradients(loss_cb: LossCallback, variables: Variables):
    """
    Evaluate a loss and its gradients with respect to the given variables.

    Args:
        loss_cb: Callable computing the loss. It is called once, with no
            arguments, while a `tf.GradientTape` is recording.
        variables: Variables to differentiate the loss with respect to.
    Return:
        A `(loss, grads)` pair, where `grads` is `tape.gradient(loss, variables)`
        (entries may be None for variables the loss does not depend on —
        per TF GradientTape semantics).
    """
    tape = tf.GradientTape()
    with tape:
        value = loss_cb()
    # Gradient is taken after the recording context has closed.
    gradients = tape.gradient(value, variables)
    return value, gradients


def optimize(loss_cb: LossCallback,
             optimizer: tf.train.Optimizer,
             variables: List[tfe.Variable],
             steps: int,
             step_cb: Optional[StepCallback] = None):
    """
    Run a fixed number of gradient-descent-style update steps.

    Each iteration computes the loss and gradients via `loss_gradients`,
    applies them with `optimizer.apply_gradients`, and then reports progress
    to `step_cb` if it is callable.

    Args:
        loss_cb: Callable computing the loss (called with no arguments).
        optimizer: Optimizer whose `apply_gradients` performs each update.
        variables: Variables to optimise; paired positionally with the gradients.
        steps: Number of iterations to run (none are run if it is <= 0).
        step_cb: Optional hook called as `step_cb(iteration, loss, grads)` after
            each update; its return value is ignored.
    """
    for step in range(steps):
        current_loss, gradients = loss_gradients(loss_cb, variables)
        grads_and_vars = zip(gradients, variables)
        optimizer.apply_gradients(grads_and_vars)
        if not callable(step_cb):
            continue
        step_cb(step, current_loss, gradients)


# @contextmanager
# def unconstrain_variables(variables: List[tfe.Variable]):
#     def switch(constrained: bool = False):
#         for v in variables:
#             v.is_constrained = constrained
#     switch(False)
#     try:
#         yield
#     finally:
#         switch(True)
back to top