https://github.com/GPflow/GPflow
Raw File
Tip revision: 97d2d56a002f1c9c1d91c90e6bf75e17e3b66dd7 authored by Artem Artemev on 14 November 2019, 14:55:16 UTC
Change regexp
Tip revision: 97d2d56
__config__.py
import builtins
import contextlib
import enum
import os
from dataclasses import dataclass, field, replace
from typing import Optional

import numpy as np
import tabulate
import tensorflow as tf

__all__ = [
    "Config",
    "default_summary_fmt",
    "default_float",
    "default_jitter",
    "default_int",
    "set_default_summary_fmt",
    "set_default_float",
    "set_default_jitter",
    "set_default_int",
    "as_context",
    "config",
    "set_config",
    "default_positive_minimum",
    "set_default_positive_minimum",
]

# The currently active Config. It is None until `set_config` is first called;
# the module makes that call itself at import time, at the end of this file.
__config = None


class _Values(enum.Enum):
    """Names and built-in defaults of the GPflow settings.

    For every member, the ``name`` property returns the environment variable
    that may override the setting (e.g. ``_Values.SUMMARY_FMT.name`` is
    ``"GPFLOW_SUMMARY_FMT"``), and the ``value`` property returns the built-in
    default used when that variable is not set.
    """
    INT = np.int32
    FLOAT = np.float64
    POSITIVE_MINIMUM = None
    SUMMARY_FMT = None
    JITTER = 1e-6

    def __new__(cls, default_value):
        # Enum turns members with equal values into aliases. With the raw defaults,
        # SUMMARY_FMT (None) would silently become an alias of POSITIVE_MINIMUM
        # (also None) and report the env-var name "GPFLOW_POSITIVE_MINIMUM".
        # Give every member a unique positional ``_value_`` and keep the default
        # in a separate attribute so no two settings can ever collapse together.
        member = object.__new__(cls)
        member._value_ = len(cls.__members__)
        member._default_value = default_value
        return member

    @property
    def name(self):
        return f"GPFLOW_{super().name}"

    @property
    def value(self):
        # Expose the built-in default, not the internal positional index.
        return self._default_value


def default(value: "_Values"):
    """Return the default for a setting, honouring an environment-variable override.

    Reads the environment variable named ``value.name`` and falls back to the
    built-in default ``value.value`` when it is unset. Environment variables are
    always strings, so an override is converted to match the built-in default:

    * default is a type (e.g. ``np.int32``): the string is parsed as a NumPy
      dtype name (``"int64"`` -> ``np.int64``);
    * default is a ``float``: the string is parsed with ``float()``;
    * otherwise (e.g. a ``None`` default): the raw string is returned.

    Raises:
        ValueError: if the environment value cannot be converted.
    """
    raw = os.getenv(value.name)
    builtin_default = value.value
    if raw is None:
        return builtin_default

    try:
        if isinstance(builtin_default, type):
            return np.dtype(raw).type
        if isinstance(builtin_default, float):
            return float(raw)
    except (TypeError, ValueError) as err:
        raise ValueError(f"Invalid value {raw!r} for environment variable {value.name}") from err
    return raw


@dataclass(frozen=True)
class Config:
    """
    Immutable object for storing global GPflow settings.

    When a field is not passed explicitly, its value comes from
    ``default(_Values.<FIELD>)``, which consults the ``GPFLOW_<FIELD>``
    environment variable before falling back to the built-in default.

    Args:
        int: Integer data type, e.g. int32 or int64.
        float: Float data type, e.g. float32 or float64.
        jitter: Jitter value. Mainly used for making badly conditioned matrices more stable.
            Default value is `1e-6`.
        positive_minimum: Lower bound for the positive transformation. Built-in default is `None`.
        summary_fmt: Summary format for module printing. Built-in default is `None`.
    """

    # NOTE: the fields named `int` and `float` shadow the builtins inside this class
    # body. After the `float` field below, a bare `float` annotation would evaluate to
    # that `field(...)` object, so the builtin is spelled `builtins.float`.
    int: type = field(default_factory=lambda: default(_Values.INT))
    float: type = field(default_factory=lambda: default(_Values.FLOAT))
    jitter: builtins.float = field(default_factory=lambda: default(_Values.JITTER))
    positive_minimum: Optional[builtins.float] = field(
        default_factory=lambda: default(_Values.POSITIVE_MINIMUM)
    )
    summary_fmt: Optional[str] = field(default_factory=lambda: default(_Values.SUMMARY_FMT))


def config() -> Config:
    """Return the active GPflow config.

    This is the object most recently installed with `set_config`
    (None if `set_config` has never been called).
    """
    active_config = __config
    return active_config


def default_summary_fmt():
    """Return the summary format stored in the active config."""
    active = config()
    return active.summary_fmt


def default_int():
    """Return the integer data type stored in the active config."""
    active = config()
    return active.int


def default_float():
    """Return the float data type stored in the active config."""
    active = config()
    return active.float


def default_jitter():
    """Return the jitter value stored in the active config."""
    active = config()
    return active.jitter


def default_positive_minimum():
    """Return the positive-transformation lower bound stored in the active config."""
    active = config()
    return active.positive_minimum


def set_config(new_config: Config):
    """Install `new_config` as the active GPflow config.

    The object is stored as-is, without copying or validation, and is what
    `config()` returns from then on.
    """
    global __config
    __config = new_config


def set_default_int(value_type):
    """Set the default integer data type in the active config.

    Args:
        value_type: Any value `tf.as_dtype` accepts (e.g. a NumPy or TensorFlow
            dtype) that denotes an integer type.

    Raises:
        TypeError: if `value_type` is not a valid dtype, or is not an integer dtype.
    """
    try:
        tf_dtype = tf.as_dtype(value_type)  # Test that it's a tensorflow-valid dtype
    except TypeError as err:
        raise TypeError(f"{value_type} is not a valid tf or np dtype") from err

    if not tf_dtype.is_integer:
        raise TypeError(f"{value_type} is not an integer dtype")

    # Store the NumPy scalar type, so the config always holds the same kind of object
    # as its built-in default (`np.int32`) regardless of the form the caller passed.
    set_config(replace(config(), int=tf_dtype.as_numpy_dtype))


def set_default_float(value_type):
    """Set the default floating-point data type in the active config.

    Args:
        value_type: Any value `tf.as_dtype` accepts (e.g. a NumPy or TensorFlow
            dtype) that denotes a floating-point type.

    Raises:
        TypeError: if `value_type` is not a valid dtype, or is not a float dtype.
    """
    try:
        tf_dtype = tf.as_dtype(value_type)  # Test that it's a tensorflow-valid dtype
    except TypeError as err:
        raise TypeError(f"{value_type} is not a valid tf or np dtype") from err

    if not tf_dtype.is_floating:
        raise TypeError(f"{value_type} is not a float dtype")

    # Store the NumPy scalar type, so the config always holds the same kind of object
    # as its built-in default (`np.float64`) regardless of the form the caller passed.
    set_config(replace(config(), float=tf_dtype.as_numpy_dtype))


def set_default_jitter(value: float):
    """Set the default jitter in the active config.

    Args:
        value: A non-negative scalar: a Python `float`, a NumPy floating scalar
            (e.g. `np.float32(1e-6)`), or a 0-d `tf.Tensor` / `np.ndarray`.

    Raises:
        TypeError: if `value` is not a scalar of one of those kinds.
        ValueError: if `value` is negative or NaN.
    """
    is_zero_dim_array = isinstance(value, (tf.Tensor, np.ndarray)) and len(value.shape) == 0
    # `np.floating` covers NumPy float scalars such as `np.float32`, which are
    # neither Python floats nor arrays.
    if not (is_zero_dim_array or isinstance(value, (float, np.floating))):
        raise TypeError("Expected float32 or float64 scalar value")

    # `not value >= 0` (rather than `value < 0`) also rejects NaN.
    if not value >= 0:
        raise ValueError("Jitter must be non-negative")

    set_config(replace(config(), jitter=value))


def set_default_summary_fmt(value: str):
    """Set the format used when printing module summaries.

    Args:
        value: One of `tabulate.tabulate_formats`, `'notebook'`, or None.

    Raises:
        ValueError: if `value` is not one of the supported formats.
    """
    supported = tabulate.tabulate_formats + ['notebook', None]
    if value in supported:
        set_config(replace(config(), summary_fmt=value))
        return
    raise ValueError(f"Summary does not support '{value}' format")


def set_default_positive_minimum(value: float):
    """Set the lower bound for the positive transformation in the active config.

    Args:
        value: A non-negative scalar: a Python `float`, a NumPy floating scalar
            (e.g. `np.float32(1e-6)`), or a 0-d `tf.Tensor` / `np.ndarray`.

    Raises:
        TypeError: if `value` is not a scalar of one of those kinds.
        ValueError: if `value` is negative or NaN.
    """
    is_zero_dim_array = isinstance(value, (tf.Tensor, np.ndarray)) and len(value.shape) == 0
    # `np.floating` covers NumPy float scalars such as `np.float32`, which are
    # neither Python floats nor arrays.
    if not (is_zero_dim_array or isinstance(value, (float, np.floating))):
        raise TypeError("Expected float32 or float64 scalar value")

    # `not value >= 0` (rather than `value < 0`) also rejects NaN.
    if not value >= 0:
        raise ValueError("Value must be non-negative")

    set_config(replace(config(), positive_minimum=value))


@contextlib.contextmanager
def as_context(temporary_config: Optional[Config] = None):
    """Activate a GPflow config for the duration of a `with` block. Useful for testing.

    Args:
        temporary_config: Config to activate inside the block. When omitted, a copy
            of the currently active config is used.

    On exit, including when the block raises, the config that was active on entry
    is reinstalled. Any `set_default_*` changes made inside the block are therefore
    discarded.
    """
    saved_config = config()
    to_activate = temporary_config if temporary_config is not None else replace(saved_config)
    try:
        set_config(to_activate)
        yield
    finally:
        set_config(saved_config)


# Install the initial global config at import time. Its fields come from the
# GPFLOW_* environment variables, or the built-in defaults when those are unset.
set_config(new_config=Config())
back to top