import os

import numpy as np
import tensorflow as tf

__all__ = [
    "default_float",
    "default_int",
    "default_jitter",
    "set_default_float",
    "set_default_int",
    "set_default_jitter",
]

# Environment variables that override the built-in defaults at import time.
_ENV_JITTER = "GPFLOW_JITTER"
_ENV_FLOAT = "GPFLOW_FLOAT"
_ENV_INT = "GPFLOW_INT"

# Built-in defaults used when the corresponding environment variable is unset.
_JITTER_VALUE = 1e-6
_FLOAT_VALUE = np.float32
_INT_VALUE = np.int32


def _dtype_from_env(name, default):
    """Return the numpy scalar type named by environment variable *name*.

    ``os.getenv`` always yields a string, so a set value is validated and
    converted with ``tf.as_dtype`` (the same check the ``set_default_*``
    setters use) into a numpy type, matching the type of the built-in
    defaults. An unset or empty variable yields *default*.

    Raises:
        TypeError: if the variable names something that is not a dtype.
    """
    raw = os.getenv(name)
    if not raw:
        return default
    try:
        return tf.as_dtype(raw).as_numpy_dtype
    except TypeError as err:
        raise TypeError(
            f"Environment variable {name}={raw!r} is not a valid dtype"
        ) from err


def _jitter_from_env(name, default):
    """Return the float stored in environment variable *name*, or *default*.

    Raises:
        ValueError: if the variable is set but cannot be parsed as a float.
    """
    raw = os.getenv(name)
    if not raw:
        return default
    try:
        return float(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name}={raw!r} is not a valid float"
        ) from err


class _Config:
    """Mutable holder for the module-wide integer/float dtypes and jitter."""

    def __init__(self):
        self._int = _dtype_from_env(_ENV_INT, _INT_VALUE)
        self._float = _dtype_from_env(_ENV_FLOAT, _FLOAT_VALUE)
        self._jitter = _jitter_from_env(_ENV_JITTER, _JITTER_VALUE)


# Module-level singleton read and written by the accessors below. Name
# mangling does not apply here because these are module-level functions.
__config = _Config()


def default_int():
    """Return the current default integer dtype."""
    return __config._int


def default_float():
    """Return the current default floating-point dtype."""
    return __config._float


def default_jitter():
    """Return the current default jitter value."""
    return __config._jitter


def set_default_int(value_type):
    """Set the default integer dtype.

    Args:
        value_type: anything accepted by ``tf.as_dtype`` (tf or np dtype).

    Raises:
        TypeError: if ``value_type`` is not convertible to a TensorFlow dtype.
    """
    try:
        tf.as_dtype(value_type)  # Validate that the input is an eligible dtype.
    except TypeError as err:
        raise TypeError("Expected tf or np dtype argument") from err
    __config._int = value_type


def set_default_float(value_type):
    """Set the default floating-point dtype.

    Args:
        value_type: anything accepted by ``tf.as_dtype`` (tf or np dtype).

    Raises:
        TypeError: if ``value_type`` is not convertible to a TensorFlow dtype.
    """
    try:
        tf.as_dtype(value_type)  # Validate that the input is an eligible dtype.
    except TypeError as err:
        raise TypeError("Expected tf or np dtype argument") from err
    __config._float = value_type


def set_default_jitter(value: float):
    """Set the default jitter.

    Args:
        value: a Python float, a numpy floating scalar (e.g. ``np.float32``),
            or a rank-0 ``np.ndarray`` / ``tf.Tensor``.

    Raises:
        ValueError: if ``value`` is none of the accepted scalar kinds.
    """
    is_rank0_array = (
        isinstance(value, (tf.Tensor, np.ndarray)) and len(value.shape) == 0
    )
    # np.floating covers np.float32, which is not a subclass of Python float.
    is_float_scalar = isinstance(value, (float, np.floating))
    if not (is_rank0_array or is_float_scalar):
        raise ValueError("Expected float32 or float64 scalar value")
    __config._jitter = value