# utilities.py

import re
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import tensorflow as tf
from tabulate import tabulate

from ..config import summary_fmt
from ..base import Parameter

__all__ = [
    "set_trainable",
    "multiple_assign",
    "training_loop",
    "print_summary",
]


def set_trainable(model: tf.Module, flag: bool = False):
    """
    Set trainable flag for all `tf.Variable`s and `gpflow.Parameter`s in a module.
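
    Example (a sketch; `model` is any hypothetical `tf.Module`, e.g. a GPflow
    model with a `kernel` submodule, neither of which is defined here):

        set_trainable(model, False)        # freeze everything
        set_trainable(model.kernel, True)  # unfreeze one submodule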
    """
    # Iterate over *all* variables, not just the currently trainable ones;
    # otherwise the flag could never be switched back to True.
    for variable in model.variables:
        variable._trainable = flag


def multiple_assign(input: tf.Module, vars_dict: Dict[str, tf.Tensor]):
    """
    Assigns new values to multiple variables at once. The dictionary keys are the
    paths to the `tf.Variable`s or `gpflow.Parameter`s of the input module.

    :param input: `tf.Module`.
    :param vars_dict: a dictionary with keys of the form "module.path.to.variable" and new value tensors.
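
    Example (a sketch with hypothetical paths; the keys must match the paths
    reported by `leaf_components` / `print_summary` for the actual model):

        multiple_assign(model, {
            "GPR.kernel.lengthscale": tf.constant(2.0, dtype=tf.float64),
            "GPR.likelihood.variance": tf.constant(0.1, dtype=tf.float64),
        })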
    """
    reference_var_dict = leaf_components(input)
    for path, value in vars_dict.items():
        reference_var_dict[path].assign(value)


def training_loop(closure: Callable[..., tf.Tensor],
                  optimizer: Optional[tf.optimizers.Optimizer] = None,
                  var_list: Optional[List[tf.Variable]] = None,
                  maxiter: int = 1000,
                  jit=False):
    """
    Simple generic training loop. At each iteration, it uses a GradientTape to
    compute the gradients of a loss function with respect to a set of variables.

    :param closure: Callable that constructs the loss function for the data and
        model being trained.
    :param optimizer: `tf.optimizers` / `tf.keras.optimizers` instance that updates
        the variables by applying the corresponding loss gradients. Defaults to
        Adam with default settings.
    :param var_list: List of model variables to be learnt during training.
    :param maxiter: Maximum number of optimization steps to run.
    :param jit: If `True`, wraps the optimization step in `tf.function` for
        graph-mode execution.
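
    Example (a sketch; `model` is a hypothetical GPflow model whose
    `neg_log_marginal_likelihood` method serves as the loss closure):

        training_loop(model.neg_log_marginal_likelihood,
                      optimizer=tf.optimizers.Adam(0.01),
                      var_list=model.trainable_variables,
                      maxiter=1000)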
    """

    optimizer = tf.optimizers.Adam() if optimizer is None else optimizer

    def optimization_step():
        with tf.GradientTape() as tape:
            tape.watch(var_list)
            loss = closure()
        # Compute gradients outside the tape context so that the gradient
        # computation itself is not recorded on the tape.
        grads = tape.gradient(loss, var_list)
        optimizer.apply_gradients(zip(grads, var_list))

    if jit:
        optimization_step = tf.function(optimization_step)

    for _ in range(maxiter):
        optimization_step()


def print_summary(module: tf.Module, fmt: Optional[str] = None):
    """
    Prints a summary of the parameters and variables contained in a tf.Module.
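
    Example (a sketch; `model` is any hypothetical `tf.Module`):

        print_summary(model)                  # uses the configured default format
        print_summary(model, fmt="notebook")  # renders an HTML table in IPython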
    """

    fmt = fmt if fmt is not None else summary_fmt()
    column_names = ['name', 'class', 'transform', 'trainable', 'shape', 'dtype', 'value']

    def get_name(v):
        return v.__class__.__name__

    def get_transform(v):
        if hasattr(v, "transform") and v.transform is not None:
            return v.transform.__class__.__name__
        return None

    merged_leaf_components = _merge_leaf_components(leaf_components(module))

    column_values = [[
        path,
        get_name(variable),
        get_transform(variable),
        variable.trainable,
        variable.shape,
        variable.dtype.name,
        _str_tensor_value(variable.numpy())
    ] for path, variable in merged_leaf_components.items()]

    if fmt == "notebook":
        from IPython.core.display import display, HTML
        tab = tabulate(column_values, headers=column_names, tablefmt="html")
        display(HTML(tab))
    else:
        print(tabulate(column_values, headers=column_names, tablefmt=fmt))


def leaf_components(input: tf.Module):
    return _get_leaf_components(input)


def _merge_leaf_components(
        input: Dict[str, Union[tf.Tensor, Parameter]]) -> Dict[str, Union[tf.Tensor, Parameter]]:
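    r"""
    Merges entries whose values are the same `Parameter`/`tf.Tensor` object, so
    that the summary shows one row per underlying object, with all of its paths
    joined by newlines. A sketch with a hypothetical shared parameter `p`:

        _merge_leaf_components({"m.a": p, "m.b": p})  # -> {"m.a\nm.b": p}
    """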
    if len(set(input.values())) == len(input):
        return input
    tmp_dict = dict()  # maps each parameter/tensor to its newline-joined paths
    for path, item in input.items():
        if item in tmp_dict:
            tmp_dict[item] = f"{tmp_dict[item]}\n{path}"
        else:
            tmp_dict[item] = path
    return {path: item for item, path in tmp_dict.items()}


def _get_leaf_components(input: tf.Module, prefix: Optional[str] = None):
    """
    Returns a dictionary of all `gpflow.Parameter`s and `tf.Variable`s contained in
    the submodules of a given `tf.Module`. Each key is the relative path to the
    parameter (or variable) inside the module, constructed recursively by prefixing
    the path to the current submodule. Designed as a helper for `print_summary`.

    :param input: `tf.Module`, including `keras.Model`, `keras.layers.Layer` and
        `gpflow.Module`.
    :param prefix: string containing the relative path to the module, by default None.
    :return: dictionary mapping relative paths to `Parameter`s and `tf.Variable`s.
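
    Example (a sketch; `m` is a hypothetical model of class `GPR` whose `kernel`
    submodule holds a `variance` Parameter):

        _get_leaf_components(m)
        # -> {"GPR.kernel.variance": <Parameter>, ...}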
    """
    if not isinstance(input, tf.Module):
        raise TypeError("Input object expected to have `tf.Module` type")

    prefix = input.__class__.__name__ if prefix is None else prefix
    var_dict = dict()

    for key, submodule in vars(input).items():
        if key in tf.Module._TF_MODULE_IGNORED_PROPERTIES:
            continue
        elif isinstance(submodule, (Parameter, tf.Variable)):
            var_dict[f"{prefix}.{key}"] = submodule
        elif isinstance(submodule, tf.Module):
            submodule_var = _get_leaf_components(submodule, prefix=f"{prefix}.{key}")
            var_dict.update(submodule_var)

    return var_dict


@lru_cache()
def _first_three_elements_regexp():
    num_re = r"[+\-]?(?:0|[1-9]\d*)(?:\.\d*)?(?:[eE][+\-]?\d+)?"
    pat_re = rf"^(?:(\[+)\s*)?({num_re})(?:\s+({num_re})(?:\s+({num_re}))?)?.*?"
    return re.compile(pat_re)


def _str_tensor_value(value: np.ndarray):
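    """
    Renders a tensor value compactly: tensors with up to three elements are
    printed in full; larger ones are truncated to their first three elements
    followed by "...". For example, as this function is written:

        _str_tensor_value(np.arange(2))   # -> "[0 1]"
        _str_tensor_value(np.arange(10))  # -> "[0, 1, 2..."
    """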
    value_str = str(value)
    if value.size <= 3:
        return value_str

    max_chars = 500
    value_str = value_str[:max_chars]
    regexp = _first_three_elements_regexp()
    match = regexp.match(value_str)
    assert match is not None
    brackets, elem1, elem2, elem3 = match.groups()

    out = f"{elem1}"
    if elem2 is not None:
        out = f"{out}, {elem2}"
        if elem3 is not None:
            out = f"{out}, {elem3}"
    if brackets is not None:
        out = f"{brackets}{out}..."

    return out