https://github.com/tensorly/tensorly
Revision c08103ac61030a97f15f88cb8a9bf56eea207117 authored by Jean Kossaifi on 23 August 2019, 17:25:22 UTC, committed by GitHub on 23 August 2019, 17:25:22 UTC
1 parent de8ad5f
Tip revision: c08103ac61030a97f15f88cb8a9bf56eea207117 authored by Jean Kossaifi on 23 August 2019, 17:25:22 UTC
README: use new tucker_tensor syntax
README: use new tucker_tensor syntax
Tip revision: c08103a
tensorflow_backend.py
try:
import tensorflow as tf
except ImportError as error:
message = ('Impossible to import TensorFlow.\n'
'To use TensorLy with the TensorFlow backend, '
'you must first install TensorFlow!')
raise ImportError(message) from error
import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution(device_policy=tfe.DEVICE_PLACEMENT_SILENT)
import numpy as np
from . import Backend
class TensorflowBackend(Backend):
backend_name = 'tensorflow'
@staticmethod
def context(tensor):
return {'dtype': tensor.dtype}
@staticmethod
def tensor(data, dtype=np.float32, device=None, device_id=None):
if isinstance(data, tf.Tensor):
return data
out = tf.constant(data, dtype=dtype)
return out.gpu(device_id) if device == 'gpu' else out
@staticmethod
def is_tensor(tensor):
return isinstance(tensor, tf.Tensor)
@staticmethod
def to_numpy(tensor):
if isinstance(tensor, np.ndarray):
return tensor
elif isinstance(tensor, tf.Tensor):
return tensor.numpy()
else:
return tensor
@staticmethod
def ndim(tensor):
return len(tensor.get_shape()._dims)
@staticmethod
def shape(tensor):
return tuple(tensor.shape.as_list())
@staticmethod
def arange(start, stop=None, step=1, dtype=np.float32):
if stop is None:
stop = start
start = 0
return tf.range(start=start, limit=stop, delta=step, dtype=dtype)
def clip(self, tensor_, a_min=None, a_max=None, inplace=False):
if a_min is not None:
a_min = self.tensor(a_min, **self.context(tensor_))
else:
a_min = tf.reduce_min(tensor_)
if a_max is not None:
a_max = self.tensor(a_max, **self.context(tensor_))
else:
a_max = tf.reduce_max(tensor_)
return tf.clip_by_value(tensor_, clip_value_min=a_min, clip_value_max=a_max)
def moveaxis(self, tensor, source, target):
axes = list(range(self.ndim(tensor)))
if source < 0: source = axes[source]
if target < 0: target = axes[target]
try:
axes.pop(source)
except IndexError:
raise ValueError('Source should verify 0 <= source < tensor.ndim'
'Got %d' % source)
try:
axes.insert(target, source)
except IndexError:
raise ValueError('Destination should verify 0 <= destination < tensor.ndim'
'Got %d' % target)
return tf.transpose(tensor, axes)
@staticmethod
def norm(tensor, order=2, axis=None):
if order == 'inf':
order = np.inf
res = tf.norm(tensor, ord=order, axis=axis)
if res.shape == ():
return res.numpy()
return res
def dot(self, tensor1, tensor2):
return tf.tensordot(tensor1, tensor2, axes=([self.ndim(tensor1) - 1], [0]))
@staticmethod
def conj(x, *args, **kwargs):
"""WARNING: IDENTITY FUNCTION (does nothing)
This backend currently does not support complex tensors
"""
return x
@staticmethod
def solve(lhs, rhs):
squeeze = False
if rhs.ndim == 1:
squeeze = [-1]
rhs = tf.reshape(rhs, (-1, 1))
res = tf.linalg.solve(lhs, rhs)
if squeeze:
res = tf.squeeze(res, squeeze)
return res
@staticmethod
def truncated_svd(matrix, n_eigenvecs=None):
"""Computes an SVD on `matrix`
Parameters
----------
matrix : 2D-array
n_eigenvecs : int, optional, default is None
if specified, number of eigen[vectors-values] to return
Returns
-------
U : 2D-array
of shape (matrix.shape[0], n_eigenvecs)
contains the right singular vectors
S : 1D-array
of shape (n_eigenvecs, )
contains the singular values of `matrix`
V : 2D-array
of shape (n_eigenvecs, matrix.shape[1])
contains the left singular vectors
"""
dim_1, dim_2 = matrix.shape
if dim_1 <= dim_2:
min_dim = dim_1
else:
min_dim = dim_2
if n_eigenvecs is None or n_eigenvecs > min_dim:
full_matrices = True
else:
full_matrices = False
S, U, V = tf.svd(matrix, full_matrices=full_matrices)
U, S, V = U[:, :n_eigenvecs], S[:n_eigenvecs], tf.transpose(V)[:n_eigenvecs, :]
return U, S, V
@property
def SVD_FUNS(self):
return {'numpy_svd': self.partial_svd,
'truncated_svd': self.truncated_svd}
_FUN_NAMES = [
# source_fun, target_fun
(np.int32, 'int32'),
(np.int64, 'int64'),
(np.float32, 'float32'),
(np.float64, 'float64'),
(tf.ones, 'ones'),
(tf.zeros, 'zeros'),
(tf.diag, 'diag'),
(tf.zeros_like, 'zeros_like'),
(tf.eye, 'eye'),
(tf.reshape, 'reshape'),
(tf.transpose, 'transpose'),
(tf.where, 'where'),
(tf.sign, 'sign'),
(tf.abs, 'abs'),
(tf.sqrt, 'sqrt'),
(tf.linalg.qr, 'qr'),
#(tf.linalg.solve, 'solve'),
(tf.argmin, 'argmin'),
(tf.argmax, 'argmax'),
(tf.stack, 'stack'),
(tf.identity, 'copy'),
(tf.concat, 'concatenate'),
(tf.stack, 'stack'),
(tf.reduce_min, 'min'),
(tf.reduce_max, 'max'),
(tf.reduce_mean, 'mean'),
(tf.reduce_sum, 'sum'),
(tf.reduce_prod, 'prod'),
(tf.reduce_all, 'all'),
]
for source_fun, target_fun_name in _FUN_NAMES:
TensorflowBackend.register_method(target_fun_name, source_fun)
del _FUN_NAMES
Computing file changes ...