https://github.com/GPflow/GPflow
Tip revision: 2b0e60b4dec5ee701d4a6e5fc0053afbc007c969 ("Initial commit"), authored by Artem Artemev on 17 June 2019, 08:11:06 UTC
training.py
from typing import Callable, List, Optional
import tensorflow as tf
def set_trainable(model: tf.Module, flag: bool = False):
    """
    Set the trainable status of all trainable variables of a `tf.Module` to `flag`,
    by flipping the private `_trainable` attribute of each variable.
    """
    for variable in model.trainable_variables:
        variable._trainable = flag
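
# An illustrative usage sketch, not part of the original file: freezing a
# module's parameters. Because this flips the private `_trainable` attribute,
# frozen variables no longer appear in `model.trainable_variables`, e.g.:
#
#     module = tf.Module()
#     module.w = tf.Variable(1.0)
#     set_trainable(module, False)
#     assert len(module.trainable_variables) == 0
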
def training_loop(closure: Callable[..., tf.Tensor],
                  optimizer: Optional[tf.optimizers.Optimizer] = None,
                  var_list: Optional[List[tf.Variable]] = None,
                  maxiter: int = 1000,
                  jit: bool = False):
    """
    Simple generic training loop. At each iteration it uses a GradientTape to compute
    the gradients of the loss with respect to a set of variables and applies them.

    :param closure: Callable that constructs the loss function from the data and the
        model being trained.
    :param optimizer: tf.optimizers / tf.keras.optimizers instance that updates the
        variables by applying the corresponding loss gradients. Defaults to Adam with
        default settings.
    :param var_list: List of model variables to be learnt during training.
    :param maxiter: Maximum number of optimization iterations.
    :param jit: If True, wrap the optimization step in `tf.function` to compile it
        into a graph.
    """
    optimizer = tf.optimizers.Adam() if optimizer is None else optimizer

    def optimization_step():
        with tf.GradientTape() as tape:
            tape.watch(var_list)
            loss = closure()
        grads = tape.gradient(loss, var_list)
        optimizer.apply_gradients(zip(grads, var_list))

    if jit:
        # Compile the step into a TensorFlow graph for faster execution.
        optimization_step = tf.function(optimization_step)

    for _ in range(int(maxiter)):
        optimization_step()
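
# A minimal usage sketch, not part of the original file: the `Linear` module,
# data, and hyperparameters below are illustrative assumptions, not GPflow API.
# It fits y = w * x + b by minimising squared error with `training_loop`.
if __name__ == "__main__":
    import numpy as np

    class Linear(tf.Module):
        def __init__(self):
            super().__init__()
            self.w = tf.Variable(0.0)
            self.b = tf.Variable(0.0)

        def __call__(self, x):
            return self.w * x + self.b

    x = np.linspace(0.0, 1.0, 50).astype(np.float32)
    y = 3.0 * x + 1.0

    model = Linear()

    def loss_closure() -> tf.Tensor:
        return tf.reduce_mean(tf.square(model(x) - y))

    training_loop(loss_closure,
                  optimizer=tf.optimizers.Adam(learning_rate=0.1),
                  var_list=list(model.trainable_variables),
                  maxiter=500,
                  jit=True)
    print(model.w.numpy(), model.b.numpy())  # should approach (3.0, 1.0)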