https://github.com/GPflow/GPflow
Raw File
Tip revision: 2099f7dbbe09cb9896f7fb098c9d9aef5800b851 authored by ST John on 18 March 2020, 10:57:22 UTC
Merge branch 'develop' of github.com:GPflow/GPflow into tf2.0-compatible
Tip revision: 2099f7d
__config__.py
import builtins
import contextlib
import enum
import os
from dataclasses import dataclass, field, replace
from typing import Dict, Optional

import numpy as np
import tabulate
import tensorflow as tf
import tensorflow_probability as tfp

# Public API of this module: the Config dataclass, the context manager, and the
# getter/setter pair for every individual setting.
__all__ = [
    "Config",
    "as_context",
    "config",
    "set_config",
    "default_float",
    "set_default_float",
    "default_int",
    "set_default_int",
    "default_jitter",
    "set_default_jitter",
    "default_positive_bijector",
    "set_default_positive_bijector",
    "default_positive_minimum",
    "set_default_positive_minimum",
    "default_summary_fmt",
    "set_default_summary_fmt",
    "positive_bijector_type_map",
]

# Holder for the currently active `Config`. It is only rebound through `set_config`
# (Config is frozen, so settings change by swapping in a new instance) and is
# initialised with `Config()` at the bottom of this module.
__config = None


class _Values(enum.Enum):
    """Setting's names collection with default values. The `name` property returns the name
    of the environment variable. E.g. for the `SUMMARY_FMT` field the environment variable
    will be `GPFLOW_SUMMARY_FMT`.

    Each member's `value` is the built-in default that `_default` falls back to when the
    corresponding environment variable is not set.
    """

    # NOTE: Enum treats members with equal values as aliases of the first such member,
    # so these defaults must stay pairwise distinct.
    INT = np.int32
    FLOAT = np.float64
    POSITIVE_BIJECTOR = "softplus"
    # None means no minimum is configured by default; how consumers interpret None is
    # not visible in this file.
    POSITIVE_MINIMUM = None
    SUMMARY_FMT = "fancy_grid"
    JITTER = 1e-6

    @property
    def name(self) -> str:
        # Overrides Enum.name so the member name maps directly to its environment
        # variable, e.g. JITTER -> "GPFLOW_JITTER"; `_default` passes it to os.getenv.
        return f"GPFLOW_{super().name}"


def _default(value: _Values):
    """Checks if value is set in the environment."""
    return os.getenv(value.name, default=value.value)


def _default_numeric_type_factory(valid_types, enum_key, type_name):
    """Resolve a numeric dtype setting from the environment or its built-in default.

    Args:
        valid_types: Mapping from accepted names (e.g. "int32") to numpy scalar types.
        enum_key: The `_Values` member whose setting is looked up via `_default`.
        type_name: Human-readable kind ("int" / "float") used in the error message.

    Returns:
        One of the numpy scalar types among `valid_types`' values.

    Raises:
        TypeError: if the setting is neither one of the accepted types nor one of their names.
    """
    candidate = _default(enum_key)
    # The untouched built-in default is already a numpy type rather than a name.
    if candidate in valid_types.values():
        return candidate
    # An environment override arrives as a string name such as "float32".
    if candidate in valid_types:
        return valid_types[candidate]
    raise TypeError(f"Config cannot recognize {type_name} type.")


def _default_int_factory():
    """Default integer dtype: `GPFLOW_INT` (int16/int32/int64) if set, else np.int32."""
    return _default_numeric_type_factory(
        {"int16": np.int16, "int32": np.int32, "int64": np.int64},
        _Values.INT,
        "int",
    )


def _default_float_factory():
    """Default float dtype: `GPFLOW_FLOAT` (float16/float32/float64) if set, else np.float64."""
    return _default_numeric_type_factory(
        {"float16": np.float16, "float32": np.float32, "float64": np.float64},
        _Values.FLOAT,
        "float",
    )


def _default_jitter_factory():
    """Default jitter: `GPFLOW_JITTER` parsed as a float if set, else 1e-6.

    Raises:
        TypeError: if the configured value cannot be parsed as a float.
        ValueError: if the configured value is negative or NaN. This mirrors the
            non-negativity requirement `set_default_jitter` enforces, so an environment
            override cannot install a jitter the setter would refuse.
    """
    raw = _default(_Values.JITTER)
    try:
        value = float(raw)
    except ValueError as err:
        raise TypeError("Config cannot set the jitter value with non float type.") from err
    # `not value >= 0` rejects NaN as well as negative numbers.
    if not value >= 0:
        raise ValueError(f"Jitter must be non-negative, got {value}")
    return value


def _default_positive_bijector_factory():
    """Default positive-bijector name: `GPFLOW_POSITIVE_BIJECTOR` if set, else "softplus".

    String names are lower-cased first, as `set_default_positive_bijector` does, so e.g.
    "Exp" selects the "exp" bijector.

    Raises:
        TypeError: if the name is not a key of `positive_bijector_type_map()`.
    """
    bijector_type = _default(_Values.POSITIVE_BIJECTOR)
    if isinstance(bijector_type, str):
        bijector_type = bijector_type.lower()
    valid_names = positive_bijector_type_map()
    if bijector_type not in valid_names:
        # sorted() gives a deterministic listing (set ordering of strings is not).
        raise TypeError(
            f"Config cannot set `{bijector_type}` as a default positive bijector. "
            f"Available options: {sorted(valid_names)}"
        )
    return bijector_type


def _default_positive_minimum_factory():
    """Default positive minimum: `GPFLOW_POSITIVE_MINIMUM` parsed as a float if set, else None.

    Raises:
        TypeError: if the configured value cannot be parsed as a float.
        ValueError: if the configured value is negative or NaN. This mirrors the
            non-negativity requirement `set_default_positive_minimum` enforces.
    """
    value = _default(_Values.POSITIVE_MINIMUM)
    if value is None:
        # Neither the environment nor the built-in default sets a minimum.
        return None
    try:
        value = float(value)
    except ValueError as err:
        raise TypeError(
            "Config cannot set the positive_minimum value with non float type."
        ) from err
    # `not value >= 0` rejects NaN as well as negative numbers.
    if not value >= 0:
        raise ValueError(f"positive_minimum must be non-negative, got {value}")
    return value


def _default_summary_fmt_factory():
    """Default summary format: `GPFLOW_SUMMARY_FMT` if set, else "fancy_grid".

    The value is returned as-is; unlike `set_default_summary_fmt`, it is not checked
    against the supported formats here.
    """
    summary_fmt = _default(_Values.SUMMARY_FMT)
    return summary_fmt


@dataclass(frozen=True)
class Config:
    """
    Immutable object for storing global GPflow settings

    Any field not passed explicitly is filled in when the instance is created: from the
    matching `GPFLOW_*` environment variable if it is set, otherwise from the built-in
    default.

    Args:
        int: Integer data type: int16, int32 or int64. Built-in default is int32.
        float: Float data type: float16, float32 or float64. Built-in default is float64.
        jitter: Jitter value. Mainly used for making badly conditioned matrices more stable.
            Default value is `1e-6`.
        positive_bijector: Method for positive bijector, either "softplus" or "exp".
            Default is "softplus".
        positive_minimum: Lower level for the positive transformation, or None (the
            built-in default) when no minimum is set.
        summary_fmt: Summary format for module printing. Default is "fancy_grid".
    """

    # The fields named `int` and `float` shadow the builtins inside this class body, so a
    # later bare `float` annotation would evaluate to the `float` Field object defined
    # just above. Later annotations therefore spell out `builtins.float`.
    int: type = field(default_factory=_default_int_factory)
    float: type = field(default_factory=_default_float_factory)
    jitter: builtins.float = field(default_factory=_default_jitter_factory)
    positive_bijector: str = field(default_factory=_default_positive_bijector_factory)
    positive_minimum: Optional[builtins.float] = field(
        default_factory=_default_positive_minimum_factory
    )
    summary_fmt: Optional[str] = field(default_factory=_default_summary_fmt_factory)


def config() -> Config:
    """Return the currently active global GPflow configuration."""
    active = __config
    return active


def default_int():
    """Return the integer numpy type stored in the active config."""
    active = config()
    return active.int


def default_float():
    """Return the floating-point numpy type stored in the active config."""
    active = config()
    return active.float


def default_jitter():
    """Return the jitter value stored in the active config."""
    active = config()
    return active.jitter


def default_positive_bijector():
    """Return the name of the positive bijector stored in the active config."""
    active = config()
    return active.positive_bijector


def default_positive_minimum():
    """Return the positive minimum stored in the active config (may be None)."""
    active = config()
    return active.positive_minimum


def default_summary_fmt():
    """Return the summary format stored in the active config."""
    active = config()
    return active.summary_fmt


def set_config(new_config: Config):
    """Make `new_config` the active global GPflow configuration.

    `Config` is frozen, so individual settings are changed by building a new instance
    (for example with `dataclasses.replace`) and passing it here.

    Raises:
        TypeError: if `new_config` is not a `Config`. Checking here keeps an invalid
            object from being installed globally and only failing later, far from the
            cause, when a getter such as `default_float()` reads an attribute from it.
    """
    global __config
    if not isinstance(new_config, Config):
        raise TypeError(f"Expected a Config instance, got {type(new_config).__name__}")
    __config = new_config


def set_default_int(value_type):
    """Set the default integer dtype.

    Args:
        value_type: A TensorFlow or NumPy dtype (anything `tf.as_dtype` accepts).
            It is stored in the config as the corresponding numpy type.

    Raises:
        TypeError: if `value_type` is not a valid dtype, or is not an integer dtype.
    """
    try:
        candidate = tf.as_dtype(value_type)  # validates that TensorFlow understands it
    except TypeError:
        raise TypeError(f"{value_type} is not a valid tf or np dtype")

    if candidate.is_integer:
        set_config(replace(config(), int=candidate.as_numpy_dtype))
        return
    raise TypeError(f"{value_type} is not an integer dtype")


def set_default_float(value_type):
    """Set the default floating-point dtype.

    Args:
        value_type: A TensorFlow or NumPy dtype (anything `tf.as_dtype` accepts).
            It is stored in the config as the corresponding numpy type.

    Raises:
        TypeError: if `value_type` is not a valid dtype, or is not a floating dtype.
    """
    try:
        candidate = tf.as_dtype(value_type)  # validates that TensorFlow understands it
    except TypeError:
        raise TypeError(f"{value_type} is not a valid tf or np dtype")

    if candidate.is_floating:
        set_config(replace(config(), float=candidate.as_numpy_dtype))
        return
    raise TypeError(f"{value_type} is not a float dtype")


def set_default_jitter(value: float):
    """Set the default jitter.

    Args:
        value: A non-negative scalar. Accepted forms are a Python float, a NumPy floating
            scalar (e.g. `np.float32(1e-6)`), or a 0-d `np.ndarray` / `tf.Tensor`.

    Raises:
        TypeError: if `value` is not one of the accepted scalar kinds.
        ValueError: if `value` is negative or NaN.
    """
    is_scalar_array = isinstance(value, (tf.Tensor, np.ndarray)) and len(value.shape) == 0
    # np.float64 subclasses Python float, but np.float32 / np.float16 scalars do not
    # (and they are not ndarrays either), so accept NumPy floating scalars explicitly.
    if not (is_scalar_array or isinstance(value, (float, np.floating))):
        raise TypeError("Expected float32 or float64 scalar value")

    # `not value >= 0` rejects NaN as well as negative numbers.
    if not value >= 0:
        raise ValueError("Jitter must be non-negative")

    set_config(replace(config(), jitter=value))


def set_default_positive_bijector(value: str):
    """Set the default positive bijector by name.

    String names are matched case-insensitively (they are lower-cased before lookup).

    Raises:
        ValueError: if the name is not a key of `positive_bijector_type_map()`.
    """
    valid = positive_bijector_type_map()
    name = value.lower() if isinstance(value, str) else value
    if name in valid:
        set_config(replace(config(), positive_bijector=name))
        return
    raise ValueError(f"`{name}` not in set of valid bijectors: {sorted(valid)}")


def set_default_positive_minimum(value: float):
    """Set the default lower bound for positive transformations.

    Args:
        value: A non-negative scalar. Accepted forms are a Python float, a NumPy floating
            scalar (e.g. `np.float32(1e-6)`), or a 0-d `np.ndarray` / `tf.Tensor`.

    Raises:
        TypeError: if `value` is not one of the accepted scalar kinds.
        ValueError: if `value` is negative or NaN.
    """
    is_scalar_array = isinstance(value, (tf.Tensor, np.ndarray)) and len(value.shape) == 0
    # np.float64 subclasses Python float, but np.float32 / np.float16 scalars do not
    # (and they are not ndarrays either), so accept NumPy floating scalars explicitly.
    if not (is_scalar_array or isinstance(value, (float, np.floating))):
        raise TypeError("Expected float32 or float64 scalar value")

    # `not value >= 0` rejects NaN as well as negative numbers.
    if not value >= 0:
        raise ValueError("Value must be non-negative")

    set_config(replace(config(), positive_minimum=value))


def set_default_summary_fmt(value: str):
    """Set the default format used for module summaries.

    Args:
        value: One of `tabulate.tabulate_formats`, "notebook", or None.

    Raises:
        ValueError: for any other value.
    """
    allowed = [*tabulate.tabulate_formats, "notebook", None]
    if value in allowed:
        set_config(replace(config(), summary_fmt=value))
        return
    raise ValueError(f"Summary does not support '{value}' format")


def positive_bijector_type_map() -> Dict[str, type]:
    """Map each supported positive-bijector name to its TensorFlow Probability bijector class."""
    bijectors = tfp.bijectors
    return dict(exp=bijectors.Exp, softplus=bijectors.Softplus)


@contextlib.contextmanager
def as_context(temporary_config: Optional[Config] = None):
    """Temporarily install a config for the duration of a `with` block. Useful for testing.

    Args:
        temporary_config: The config to activate inside the block. If None, a copy of the
            current config is activated instead, so any `set_default_*` calls made inside
            the block are discarded on exit.

    The previously active config is restored on exit, even if the block raises.
    """
    saved = config()
    if temporary_config is None:
        temporary_config = replace(saved)
    try:
        set_config(temporary_config)
        yield
    finally:
        set_config(saved)


# Set global config.
# `Config()` runs every field's default factory right here, so the GPFLOW_* environment
# variables are read (and validated by those factories) when this module is imported.
set_config(Config())
back to top