https://github.com/GPflow/GPflow
Revision 5a5504dd83fd282362c157016eb168603db3db90 authored by Artem Artemev on 31 January 2019, 09:57:51 UTC, committed by Artem Artemev on 31 January 2019, 09:57:51 UTC
1 parent d55e452
Tip revision: 5a5504dd83fd282362c157016eb168603db3db90 authored by Artem Artemev on 31 January 2019, 09:57:51 UTC
Add callback to scipy optimizer
Add callback to scipy optimizer
Tip revision: 5a5504d
optimize.py
from contextlib import contextmanager
from typing import Callable, List, Optional, Tuple, Union
import tensorflow as tf
# Short alias for TensorFlow's eager-execution contrib module (`tf.contrib.eager`).
tfe = tf.contrib.eager

# Input data: either a pair of tensors or a single tensor.
InputData = Union[Tuple[tf.Tensor, tf.Tensor], tf.Tensor]

# A list of TensorFlow variables (e.g. the ones passed to `loss_gradients`).
Variables = List[tf.Variable]

# Callable producing the loss tensor; `loss_gradients` invokes it with no arguments.
LossCallback = Callable[..., tf.Tensor]

# Per-step hook, invoked as `step_cb(iteration, loss, grads)`.
StepCallback = Callable[[int, tf.Tensor, List[tf.Tensor]], None]

# Optimizer type accepted by this module. A one-member Union collapses to the
# member itself, so it is spelled as the bare class; add Union members here to
# accept further optimizer types.
Optimizer = tf.train.Optimizer
def create_iterator(*args, batch_size=None, buffer_size=1000, shuffle=True, repeat=True):
    """
    Build a one-shot iterator over slices of the given tensors.

    Args:
        *args: Tensors (or tensor-convertible values) to iterate over. They are
            passed as one tuple to `tf.data.Dataset.from_tensor_slices`, which
            slices them along their first dimension, so they must all share the
            same leading dimension size.
        batch_size: Number of consecutive elements combined into one batch.
            If None, elements are yielded one at a time (no batching).
        buffer_size: Size of the shuffle buffer, i.e. the number of elements
            the shuffled dataset samples from. Only used when `shuffle` is True.
        shuffle: Whether to shuffle individual elements (before batching).
        repeat: Whether to repeat the dataset indefinitely. If False, the
            iterator stops after a single pass over the data.

    Returns:
        The iterator returned by `Dataset.make_one_shot_iterator()` for the
        (optionally shuffled, batched and repeated) dataset. Each element is a
        tuple with one entry per positional argument.

    Raises:
        ValueError: If no positional tensors are given.
    """
    if not args:
        # Fail early with a clear message instead of inside tf.data.
        raise ValueError("create_iterator requires at least one tensor argument.")
    ds = tf.data.Dataset.from_tensor_slices(args)
    if shuffle:
        # Shuffle elements before batching so batches mix individual elements
        # rather than merely reordering whole batches.
        ds = ds.shuffle(buffer_size)
    if batch_size is not None:
        ds = ds.batch(batch_size)
    if repeat:
        # Repeating after batching keeps every batch within a single epoch.
        ds = ds.repeat()
    return ds.make_one_shot_iterator()
def loss_gradients(loss_cb: LossCallback, variables: Variables):
    """
    Evaluate the loss under a gradient tape and differentiate it.

    Args:
        loss_cb: Callable invoked with no arguments; its result is the loss.
        variables: Variables to differentiate the loss with respect to.

    Returns:
        A `(loss, grads)` pair: the tensor returned by `loss_cb` and the
        gradients of that loss with respect to `variables`, as returned by
        `tf.GradientTape.gradient` (same structure and order as `variables`).
    """
    with tf.GradientTape() as recorder:
        loss_value = loss_cb()
    # The forward pass is recorded; the differentiation itself need not be.
    gradient_list = recorder.gradient(loss_value, variables)
    return loss_value, gradient_list
def optimize(loss_cb: LossCallback,
             optimizer: Optimizer,
             variables: Variables,
             steps: int,
             step_cb: Optional[StepCallback] = None):
    """
    Minimize a loss with a TensorFlow optimizer for a fixed number of steps.

    Each step evaluates the loss and its gradients via `loss_gradients` and
    applies the gradients with `optimizer.apply_gradients`.

    Args:
        loss_cb: Callable returning the loss tensor; it is re-evaluated (with
            no arguments) at every step.
        optimizer: Optimizer whose `apply_gradients` updates `variables`.
        variables: Variables to differentiate the loss with respect to and update.
        steps: Number of optimization steps; nothing runs if `steps <= 0`.
        step_cb: Optional hook invoked after each step as
            `step_cb(iteration, loss, grads)`, with the 0-based step index, the
            loss computed before that step's update, and the applied gradients.

    Raises:
        TypeError: If `step_cb` is neither None nor callable.
    """
    # Previously a non-callable step_cb was silently ignored; surface the
    # caller's mistake up front instead.
    if step_cb is not None and not callable(step_cb):
        raise TypeError("step_cb must be callable or None, got {!r}".format(step_cb))
    for iteration in range(steps):
        loss, grads = loss_gradients(loss_cb, variables)
        optimizer.apply_gradients(zip(grads, variables))
        if step_cb is not None:
            step_cb(iteration, loss, grads)
# @contextmanager
# def unconstrain_variables(variables: List[tfe.Variable]):
# def switch(constrained: bool = False):
# for v in variables:
# v.is_constrained = constrained
# switch(False)
# try:
# yield
# finally:
# switch(True)
![swh spinner](/static/img/swh-spinner.gif)
Computing file changes ...