# Source: https://github.com/GPflow/GPflow (raw file: __config__.py)
# Tip revision 2b0e60b4dec5ee701d4a6e5fc0053afbc007c969 ("Initial commit"),
# authored by Artem Artemev on 17 June 2019, 08:11:06 UTC.
import os

import numpy as np
import tensorflow as tf

__all__ = [
    "default_int",
    "default_float",
    "default_jitter",
    "set_default_int",
    "set_default_float",
    "set_default_jitter",
]

# Environment variables that, when set, override the initial defaults below
# (read once when the module-level config object is created).
_ENV_JITTER = "GPFLOW_JITTER"
_ENV_FLOAT = "GPFLOW_FLOAT"
_ENV_INT = "GPFLOW_INT"

# Built-in defaults used when the corresponding environment variable is unset.
_JITTER_VALUE = 1e-6
_FLOAT_VALUE = np.float32
_INT_VALUE = np.int32


class _Config:
    """Mutable holder for the module's global defaults.

    Initial values come from the environment variables named by ``_ENV_INT``,
    ``_ENV_FLOAT`` and ``_ENV_JITTER`` when they are set, otherwise from
    ``_INT_VALUE``, ``_FLOAT_VALUE`` and ``_JITTER_VALUE``.
    """

    def __init__(self):
        # Dtypes are stored as given: the module default (a NumPy scalar type) or,
        # when overridden via the environment, the raw string (e.g. "float64").
        self._int = os.getenv(_ENV_INT, default=_INT_VALUE)
        self._float = os.getenv(_ENV_FLOAT, default=_FLOAT_VALUE)
        # os.getenv returns a str when the variable is set; convert so the jitter is
        # always numeric instead of e.g. the string "1e-4". An unparsable value now
        # fails fast here with ValueError rather than later inside a computation.
        self._jitter = float(os.getenv(_ENV_JITTER, default=_JITTER_VALUE))


# Module-level singleton read by the default_*() getters and written by the
# set_default_*() setters.
__config = _Config()


def default_int():
    """Return the integer dtype currently configured as the default."""
    settings = __config
    return settings._int


def default_float():
    """Return the floating-point dtype currently configured as the default."""
    settings = __config
    return settings._float


def default_jitter():
    """Return the jitter value currently configured as the default."""
    settings = __config
    return settings._jitter


def set_default_int(value_type):
    """Set the dtype returned by ``default_int()``.

    :param value_type: a TensorFlow or NumPy dtype (anything ``tf.as_dtype``
        accepts) whose TensorFlow equivalent is an integer type, e.g. ``np.int64``.
        The value is stored exactly as passed.
    :raises TypeError: if ``value_type`` is not a recognised dtype, or is not an
        integer dtype.
    """
    try:
        tf_dtype = tf.as_dtype(value_type)  # Validate that it is a recognised dtype.
    except TypeError as err:
        raise TypeError("Expected tf or np dtype argument") from err

    # tf.as_dtype also accepts float/bool/string dtypes, none of which make sense
    # as the default integer type.
    if not tf_dtype.is_integer:
        raise TypeError(f"Expected an integer dtype, got {value_type!r}")

    __config._int = value_type


def set_default_float(value_type):
    """Set the dtype returned by ``default_float()``.

    :param value_type: a TensorFlow or NumPy dtype (anything ``tf.as_dtype``
        accepts) whose TensorFlow equivalent is a floating-point type, e.g.
        ``np.float64``. The value is stored exactly as passed.
    :raises TypeError: if ``value_type`` is not a recognised dtype, or is not a
        floating-point dtype.
    """
    try:
        tf_dtype = tf.as_dtype(value_type)  # Validate that it is a recognised dtype.
    except TypeError as err:
        raise TypeError("Expected tf or np dtype argument") from err

    # tf.as_dtype also accepts integer/bool/string dtypes, none of which make sense
    # as the default floating-point type.
    if not tf_dtype.is_floating:
        raise TypeError(f"Expected a floating-point dtype, got {value_type!r}")

    __config._float = value_type


def set_default_jitter(value: float):
    """Set the value returned by ``default_jitter()``.

    :param value: a scalar jitter: a Python ``float``, a NumPy floating scalar
        (e.g. ``np.float32(1e-6)``), or a rank-0 ``tf.Tensor`` / ``np.ndarray``.
        The value is stored exactly as passed.
    :raises ValueError: if ``value`` is not one of the accepted scalar forms.
    """
    is_rank0_array = isinstance(value, (tf.Tensor, np.ndarray)) and len(value.shape) == 0
    # np.float64 subclasses Python float, but np.float32/np.float16 scalars do not
    # (nor are they np.ndarray), so they must be accepted explicitly.
    is_float_scalar = isinstance(value, (float, np.floating))
    if not (is_rank0_array or is_float_scalar):
        raise ValueError("Expected float32 or float64 scalar value")

    __config._jitter = value
# end of file