from typing import Callable, List, Optional

import tensorflow as tf


def set_trainable(model: tf.Module, flag: bool = False):
    # Toggle the trainable flag of every variable in the module.
    # Note: this writes the private `_trainable` attribute directly,
    # since tf.Variable exposes `trainable` as a read-only property.
    for variable in model.trainable_variables:
        variable._trainable = flag


def training_loop(closure: Callable[..., tf.Tensor],
                  optimizer: Optional[tf.optimizers.Optimizer] = None,
                  var_list: Optional[List[tf.Variable]] = None,
                  maxiter=1e3,
                  jit=False):
    """
    Simple generic training loop. At each iteration, uses a GradientTape to
    compute the gradients of a loss function with respect to a set of variables.

    :param closure: Callable that constructs a loss function based on data
        and the model being trained
    :param optimizer: tf.optimizers or tf.keras.optimizers instance that updates
        variables by applying the corresponding loss gradients. Adam with
        default settings is used when none is given.
    :param var_list: List of model variables to be learnt during training
        (must be provided)
    :param maxiter: Maximum number of optimization iterations
    :param jit: If True, compile the optimization step with tf.function
    :return: None
    """
    optimizer = tf.optimizers.Adam() if optimizer is None else optimizer

    def optimization_step():
        # Record the forward pass so gradients can be taken w.r.t. var_list.
        with tf.GradientTape() as tape:
            tape.watch(var_list)
            loss = closure()
        grads = tape.gradient(loss, var_list)
        optimizer.apply_gradients(zip(grads, var_list))

    if jit:
        optimization_step = tf.function(optimization_step)

    for _ in range(int(maxiter)):
        optimization_step()
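

# Usage sketch (illustrative, not part of the original source): fitting a single
# variable to a toy quadratic loss with the loop above, assuming TensorFlow 2.x.
# The variable `x`, the loss, and the learning rate are all assumptions chosen
# for demonstration.
if __name__ == "__main__":
    x = tf.Variable(0.0)

    def quadratic_loss() -> tf.Tensor:
        # Loss is minimized at x == 3.0.
        return tf.square(x - 3.0)

    training_loop(quadratic_loss,
                  optimizer=tf.optimizers.Adam(learning_rate=0.1),
                  var_list=[x],
                  maxiter=200)
    print(x.numpy())  # expected to be close to 3.0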