https://github.com/GPflow/GPflow
Tip revision: 768f64426d6129b93f2159ac1ab9eb2cd8844f23 authored by Sergio Diaz on 18 March 2019, 16:13:45 UTC
Merge branch 'sergio_pasc/gpflow-2.0/move-quadrature-tests' of github.com:GPflow/GPflow into sergio_pasc/gpflow-2.0/move-quadrature-tests
Merge branch 'sergio_pasc/gpflow-2.0/move-quadrature-tests' of github.com:GPflow/GPflow into sergio_pasc/gpflow-2.0/move-quadrature-tests
Tip revision: 768f644
util.py
import copy
import logging
from typing import List, Union
import numpy as np
import tensorflow as tf
# Alias for the class of the ``None`` singleton.
NoneType = type(None)
def create_logger(name=None):
    """
    Return the shared placeholder logger.

    :param name: Accepted but currently ignored; every call yields the same
        logger regardless of the value passed.
    :returns: The `logging.Logger` registered under the fixed name
        'Temporary Logger Solution'.
    """
    logger_name = 'Temporary Logger Solution'
    return logging.getLogger(logger_name)
def default_jitter_eye(num_rows: int,
                       num_columns: Union[int, None] = None,
                       value: Union[float, None] = None) -> "tf.Tensor":
    """
    Build an identity matrix scaled by a jitter value.

    :param num_rows: Number of rows; coerced with `int`.
    :param num_columns: Number of columns; coerced with `int` when given.
        `None` is forwarded unchanged to `tf.eye`.
    :param value: Scale applied to the identity matrix. When `None`, the
        result of `default_jitter()` is used.
    :returns: The tensor `tf.eye(num_rows, num_columns) * value`, created with
        dtype `default_float()`.
    """
    value = default_jitter() if value is None else value
    num_rows = int(num_rows)
    if num_columns is not None:
        num_columns = int(num_columns)
    return tf.eye(num_rows, num_columns=num_columns, dtype=default_float()) * value
def default_jitter() -> float:
    """
    Return the default jitter value.

    :returns: The constant ``1e-6``.
    """
    jitter = 1e-6
    return jitter
def default_float() -> type:
    """
    Return the default floating-point type.

    Note that the scalar *type* itself is returned (suitable as a `dtype`
    argument), not a floating-point value.

    :returns: The class `np.float64`.
    """
    return np.float64
def default_int() -> type:
    """
    Return the default integer type.

    Note that the scalar *type* itself is returned (suitable as a `dtype`
    argument), not an integer value.

    :returns: The class `np.int32`.
    """
    return np.int32
def leading_transpose(tensor: tf.Tensor, perm: List[Union[int, type(...)]]) -> tf.Tensor:
    """
    Transposes a tensor whose leading dimensions are referred to collectively
    by an ellipsis `...` in the permutation.

    `perm` must contain exactly the explicit dimension indices plus a single
    `...`. The ellipsis expands, in order, to the first
    `rank - (len(perm) - 1)` dimensions of `tensor`, i.e. every dimension that
    is not named explicitly. So, passing `perm=[-2, ..., -1]` assumes that the
    input tensor has shape [..., A, B] and moves the leading dims between the
    A and B dimensions.

    Explicit indices may be negative or non-negative; each one is reduced
    modulo the tensor rank and therefore addresses a dimension of the full
    tensor. For the result to be a valid permutation they should refer to the
    trailing dimensions that the ellipsis does not cover.

    Example:
        a = tf.random.normal((1, 2, 3, 4, 5, 6))
        b = leading_transpose(a, [5, -3, ..., -2])
        # b has shape (6, 4, 1, 2, 3, 5)

    :param tensor: TensorFlow tensor.
    :param perm: Sequence of permutation indices containing `...`. The
        sequence passed in is not modified.
    :returns: TensorFlow tensor.
    :raises: ValueError when `...` cannot be found.
    """
    # Work on a private list: the caller's sequence is never mutated, and
    # immutable sequences such as tuples are accepted as well.
    perm = list(perm)
    idx = perm.index(...)  # ValueError propagates when `...` is absent
    perm[idx] = 0  # integer placeholder; this slot is sliced away below

    rank = tf.rank(tensor)
    # The modulo maps negative indices onto their non-negative equivalents.
    perm_tf = perm % rank

    # `len(perm) - 1` dimensions are named explicitly; the ellipsis stands for
    # all remaining (leading) ones.
    leading_dims = tf.range(rank - len(perm) + 1)
    perm = tf.concat([perm_tf[:idx], leading_dims, perm_tf[idx + 1:]], 0)
    return tf.transpose(tensor, perm)