# ops.py - utility operations from https://github.com/GPflow/GPflow
import copy
from typing import List, Optional, Union

import tensorflow as tf
import numpy as np


def cast(
    value: Union[tf.Tensor, np.ndarray], dtype: tf.DType, name: Optional[str] = None
) -> tf.Tensor:
    if not tf.is_tensor(value):
        # TODO(awav): remove this workaround once TF 2.2 is released;
        # it resolves https://github.com/tensorflow/tensorflow/issues/35938
        return tf.convert_to_tensor(value, dtype, name=name)
    return tf.cast(value, dtype, name=name)
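
# A minimal usage sketch (illustrative, not part of the original module):
#     cast(np.array([1, 2]), tf.float64)     # non-tensor input: goes via convert_to_tensor
#     cast(tf.constant([1, 2]), tf.float64)  # tensor input: goes via tf.cast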


def eye(num: int, value: tf.Tensor, dtype: Optional[tf.DType] = None) -> tf.Tensor:
    """
    Returns a [num, num] diagonal matrix with `value` (cast to `dtype` if given)
    repeated along the diagonal.
    """
    if dtype is not None:
        value = cast(value, dtype)
    return tf.linalg.diag(tf.fill([num], value))
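
# Illustrative sketch: tf.fill repeats `value` `num` times and tf.linalg.diag
# places the result on the diagonal, e.g.
#     eye(3, tf.constant(1.0), dtype=tf.float64)  # 3x3 float64 identity matrix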


def leading_transpose(
    tensor: tf.Tensor, perm: List[Union[int, type(...)]], leading_dim: int = 0
) -> tf.Tensor:
    """
    Transposes tensors with leading dimensions. Leading dimensions in
    permutation list represented via ellipsis `...`.
    When leading dimensions are found, `transpose` method
    considers them as a single grouped element indexed by 0 in `perm` list. So, passing
    `perm=[-2, ..., -1]`, you assume that your input tensor has [..., A, B] shape,
    and you want to move leading dims between A and B dimensions.
    Dimension indices in permutation list can be negative or positive. Valid positive
    indices start from 1 up to the tensor rank, viewing leading dimensions `...` as zero
    index.
    Example:
        a = tf.random.normal((1, 2, 3, 4, 5, 6))
            # [..., A, B, C],
            # where A is 1st element,
            # B is 2nd element and
            # C is 3rd element in
            # permutation list,
            # leading dimensions are [1, 2, 3]
            # which are 0th element in permutation
            # list
        b = leading_transpose(a, [3, -3, ..., -2])  # [C, A, ..., B]
        sess.run(b).shape
        output> (6, 4, 1, 2, 3, 5)
    :param tensor: TensorFlow tensor.
    :param perm: List of permutation indices.
    :returns: TensorFlow tensor.
    :raises: ValueError when `...` cannot be found.
    """
    perm = copy.copy(perm)
    idx = perm.index(...)
    perm[idx] = leading_dim

    rank = tf.rank(tensor)
    # tf.Tensor.__rmod__ converts the list to a tensor, so `% rank` maps
    # negative indices to their positive equivalents elementwise:
    perm_tf = perm % rank

    # indices of the leading dimensions grouped under the ellipsis:
    leading_dims = tf.range(rank - len(perm) + 1)
    perm = tf.concat([perm_tf[:idx], leading_dims, perm_tf[idx + 1 :]], 0)
    return tf.transpose(tensor, perm)


def broadcasting_elementwise(op, a, b):
    """
    Apply binary operation `op` to every pair in tensors `a` and `b`.

    :param op: binary operator on tensors, e.g. tf.add, tf.subtract
    :param a: tf.Tensor, shape [n_1, ..., n_a]
    :param b: tf.Tensor, shape [m_1, ..., m_b]
    :return: tf.Tensor, shape [n_1, ..., n_a, m_1, ..., m_b]
    """
    flatres = op(tf.reshape(a, [-1, 1]), tf.reshape(b, [1, -1]))
    return tf.reshape(flatres, tf.concat([tf.shape(a), tf.shape(b)], 0))
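
# A minimal sketch (illustrative values): every pairwise combination is computed,
# so the result shape is the concatenation of the input shapes.
#     a = tf.constant([1.0, 2.0])          # shape [2]
#     b = tf.constant([10.0, 20.0, 30.0])  # shape [3]
#     broadcasting_elementwise(tf.add, a, b)
#     # shape [2, 3]: [[11., 21., 31.],
#     #                [12., 22., 32.]]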


def square_distance(X, X2):
    """
    Returns ||X - X2ᵀ||²
    Due to the implementation and floating-point imprecision, the
    result may actually be very slightly negative for entries very
    close to each other.

    This function can deal with leading dimensions in X and X2. 
    In the sample case, where X and X2 are both 2 dimensional, 
    for example, X is [N, D] and X2 is [M, D], then a tensor of shape 
    [N, M] is returned. If X is [N1, S1, D] and X2 is [N2, S2, D] 
    then the output will be [N1, S1, N2, S2].
    """
    if X2 is None:
        Xs = tf.reduce_sum(tf.square(X), axis=-1, keepdims=True)
        dist = -2 * tf.matmul(X, X, transpose_b=True)
        dist += Xs + tf.linalg.adjoint(Xs)
        return dist
    Xs = tf.reduce_sum(tf.square(X), axis=-1)
    X2s = tf.reduce_sum(tf.square(X2), axis=-1)
    dist = -2 * tf.tensordot(X, X2, [[-1], [-1]])
    dist += broadcasting_elementwise(tf.add, Xs, X2s)
    return dist
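
# A minimal sketch (illustrative shapes), cross-checked against direct broadcasting:
#     X = tf.random.normal((5, 3))
#     X2 = tf.random.normal((4, 3))
#     d = square_distance(X, X2)  # shape [5, 4]
#     d_ref = tf.reduce_sum((X[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
#     # d and d_ref agree up to floating-point error;
#     square_distance(X, None)    # shape [5, 5], diagonal ~0 up to rounding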


def pca_reduce(X: tf.Tensor, Q: int) -> np.ndarray:
    """
    A helpful function for linearly reducing the dimensionality of the data X
    to Q, using the eigenvectors of the empirical covariance matrix.
    :param X: data array of size N (number of points) x D (dimensions)
    :param Q: number of latent dimensions, Q <= D
    :return: PCA projection array of size N x Q (a NumPy array).
    """
    if Q > X.shape[1]:  # pragma: no cover
        raise ValueError("Cannot have more latent dimensions than observed")
    if isinstance(X, tf.Tensor):
        X = X.numpy()
    # TODO: why not use tf.linalg.eigh?
    evals, evecs = np.linalg.eigh(np.cov(X.T))
    W = evecs[:, -Q:]  # eigenvectors of the Q largest eigenvalues
    return (X - X.mean(0)).dot(W)
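
# A minimal sketch (illustrative data):
#     X = tf.constant(np.random.randn(100, 5))
#     X_pca = pca_reduce(X, 2)  # NumPy array of shape (100, 2)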


# Draft TensorFlow alternative, kept for reference (would additionally require
# `import tensorflow_probability as tfp`):
# def pca_reduce(data: tf.Tensor, latent_dim: tf.Tensor) -> tf.Tensor:
#     """
#     A helpful function for linearly reducing the dimensionality of the data
#     to `latent_dim` dimensions.
#     :param data: data array of size N (number of points) x D (dimensions)
#     :param latent_dim: number of latent dimensions, latent_dim <= D
#     :return: PCA projection array of size [N, latent_dim].
#     """
#     assert latent_dim <= data.shape[1], 'Cannot have more latent dimensions than observed'
#     x_cov = tfp.stats.covariance(data)
#     evals, evecs = tf.linalg.eigh(x_cov)
#     W = evecs[:, -latent_dim:]
#     return (data - tf.reduce_mean(data, axis=0, keepdims=True)) @ W