Revision ac62d292d7de9b49855b7ca53f2645a10c63246a authored by rballester on 16 September 2022, 15:17:15 UTC, committed by rballester on 16 September 2022, 15:17:15 UTC
1 parent 363d461
ops.py
import tntorch as tn
import torch
def cumsum(t, dim=None):
    """
    Cumulative sum of a tensor along one or more dimensions, analogous to PyTorch's `cumsum()`.

    Only `torch.cumsum` is applied to the relevant core or factor of a copy of `t`.
    No cross-approximation is involved. The input tensor is left untouched.

    :param t: input :class:`Tensor`
    :param dim: an int or list of ints (default: all)
    :return: a :class:`Tensor` of the same shape
    """
    if dim is None:
        dims = range(t.dim())
    elif hasattr(dim, '__len__'):
        dims = dim
    else:
        dims = [dim]

    result = t.clone()
    for d in dims:
        factor = result.Us[d]
        if factor is not None:
            # Tucker factor: cumulate along axis 0, or axis 1 when a leading batch axis is present
            axis = 1 if result.batch else 0
            result.Us[d] = torch.cumsum(factor, dim=axis)
        else:
            # no factor for this mode: cumulate the core along its second-to-last axis
            # (assumed to be the mode index in the core layout)
            result.cores[d] = torch.cumsum(result.cores[d], dim=-2)
    return result
def cumprod(t, dim=None):
    """
    Computes the cumulative product of a tensor along one or several dims, similarly to PyTorch's `cumprod()`.

    The result is computed as ``exp(cumsum(log(t)))``. The cumulative sum itself is exact, but the
    element-wise logarithm and exponential use cross-approximation (:func:`tntorch.cross()`), so the
    result is approximate. Because of the logarithm, the entries of `t` must be strictly positive:
    zero entries give -inf and negative entries give NaN (as with `torch.log`).

    :param t: input :class:`Tensor` (entries should be strictly positive)
    :param dim: an int or list of ints (default: all)
    :return: a :class:`Tensor` of the same shape
    """
    # log turns products into sums, so cumprod(t) == exp(cumsum(log(t)))
    return tn.exp(tn.cumsum(tn.log(t), dim=dim))
"""
Unary operations (using cross-approximation)
"""
def abs(t):
    """
    Absolute value of every entry, obtained via cross-approximation (counterpart of PyTorch's `abs()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `|t|`
    """
    return tn.cross(torch.abs, verbose=False, tensors=t)
def acos(t):
    """
    Element-wise arccosine computed using cross-approximation; see PyTorch's `acos()`.

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor`
    """
    # torch.acos is passed directly; the former lambda wrapper only forwarded its argument
    return tn.cross(torch.acos, tensors=t, verbose=False)
def asin(t):
    """
    Arcsine of every entry, obtained via cross-approximation (counterpart of PyTorch's `asin()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `asin(t)`
    """
    return tn.cross(torch.asin, verbose=False, tensors=t)
def cos(t):
    """
    Cosine of every entry, obtained via cross-approximation (counterpart of PyTorch's `cos()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `cos(t)`
    """
    return tn.cross(torch.cos, verbose=False, tensors=t)
def cosh(t):
    """
    Hyperbolic cosine of every entry, obtained via cross-approximation (counterpart of PyTorch's `cosh()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `cosh(t)`
    """
    return tn.cross(torch.cosh, verbose=False, tensors=t)
def erf(t):
    """
    Gauss error function of every entry, obtained via cross-approximation (counterpart of PyTorch's `erf()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `erf(t)`
    """
    return tn.cross(torch.erf, verbose=False, tensors=t)
def erfinv(t):
    """
    Inverse error function of every entry, obtained via cross-approximation (counterpart of PyTorch's `erfinv()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `erfinv(t)`
    """
    return tn.cross(torch.erfinv, verbose=False, tensors=t)
def exp(t):
    """
    Exponential of every entry, obtained via cross-approximation (counterpart of PyTorch's `exp()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `exp(t)`
    """
    return tn.cross(torch.exp, verbose=False, tensors=t)
def log(t):
    """
    Natural logarithm of every entry, obtained via cross-approximation (counterpart of PyTorch's `log()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `log(t)`
    """
    return tn.cross(torch.log, verbose=False, tensors=t)
def log10(t):
    """
    Base-10 logarithm of every entry, obtained via cross-approximation (counterpart of PyTorch's `log10()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `log10(t)`
    """
    return tn.cross(torch.log10, verbose=False, tensors=t)
def log2(t):
    """
    Base-2 logarithm of every entry, obtained via cross-approximation (counterpart of PyTorch's `log2()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `log2(t)`
    """
    return tn.cross(torch.log2, verbose=False, tensors=t)
def reciprocal(t):
    """
    Reciprocal `1 / t` of every entry, obtained via cross-approximation (counterpart of PyTorch's `reciprocal()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `1 / t`
    """
    return tn.cross(torch.reciprocal, verbose=False, tensors=t)
def rsqrt(t):
    """
    Reciprocal square root `1 / sqrt(t)` of every entry, obtained via cross-approximation
    (counterpart of PyTorch's `rsqrt()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `1 / sqrt(t)`
    """
    return tn.cross(torch.rsqrt, verbose=False, tensors=t)
def sigmoid(t):
    """
    Element-wise sigmoid computed using cross-approximation; see PyTorch's `sigmoid()`.

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor`
    """
    # torch.sigmoid is passed directly; the former lambda wrapper only forwarded its argument
    return tn.cross(torch.sigmoid, tensors=t, verbose=False)
def sin(t):
    """
    Element-wise sine computed using cross-approximation; see PyTorch's `sin()`.

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor`
    """
    # torch.sin is passed directly; the former lambda wrapper only forwarded its argument
    return tn.cross(torch.sin, tensors=t, verbose=False)
def sinh(t):
    """
    Element-wise hyperbolic sine computed using cross-approximation; see PyTorch's `sinh()`.

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor`
    """
    # torch.sinh is passed directly; the former lambda wrapper only forwarded its argument
    return tn.cross(torch.sinh, tensors=t, verbose=False)
def sqrt(t):
    """
    Element-wise square root computed using cross-approximation; see PyTorch's `sqrt()`.

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor`
    """
    # torch.sqrt is passed directly; the former lambda wrapper only forwarded its argument
    return tn.cross(torch.sqrt, tensors=t, verbose=False)
def tan(t):
    """
    Tangent of every entry, obtained via cross-approximation (counterpart of PyTorch's `tan()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `tan(t)`
    """
    return tn.cross(torch.tan, verbose=False, tensors=t)
def tanh(t):
    """
    Hyperbolic tangent of every entry, obtained via cross-approximation (counterpart of PyTorch's `tanh()`).

    :param t: input :class:`Tensor`
    :return: a :class:`Tensor` approximating `tanh(t)`
    """
    return tn.cross(torch.tanh, verbose=False, tensors=t)
"""
Binary operations (using cross-approximation)
"""
def add(t1, t2):
    """
    Entry-wise sum `t1 + t2`, obtained via cross-approximation (counterpart of PyTorch's `add()`).

    :param t1: first input :class:`Tensor`
    :param t2: second input :class:`Tensor`
    :return: a :class:`Tensor` approximating `t1 + t2`
    """
    operands = [t1, t2]
    return tn.cross(torch.add, verbose=False, tensors=operands)
def atan2(t1, t2):
    """
    Entry-wise two-argument (quadrant-aware) arctangent of `t1 / t2`, obtained via cross-approximation
    (counterpart of PyTorch's `atan2()`).

    :param t1: numerator :class:`Tensor`
    :param t2: denominator :class:`Tensor`
    :return: a :class:`Tensor` approximating `atan2(t1, t2)`
    """
    operands = [t1, t2]
    return tn.cross(torch.atan2, verbose=False, tensors=operands)
def div(t1, t2):
    """
    Element-wise division; see PyTorch's `div()`.

    Unlike the other binary operations in this module, this does not call :func:`tntorch.cross()`
    itself. It simply evaluates ``t1 / t2``, so the algorithm is whatever the operands' division
    operator implements. For two :class:`Tensor` operands this is presumably cross-approximation;
    a scalar operand may take a different path. Confirm against :class:`Tensor`'s division operator.

    :param t1: dividend (a :class:`Tensor`, or any object supporting ``/``)
    :param t2: divisor (a :class:`Tensor`, or any object supporting ``/``)
    :return: the result of ``t1 / t2`` (a :class:`Tensor` for :class:`Tensor` inputs)
    """
    # delegate to the operator so scalar operands and reflected division keep working
    return t1 / t2
def mul(t1, t2):
    """
    Entry-wise (Hadamard) product `t1 * t2`, obtained via cross-approximation (counterpart of PyTorch's `mul()`).

    :param t1: first input :class:`Tensor`
    :param t2: second input :class:`Tensor`
    :return: a :class:`Tensor` approximating `t1 * t2`
    """
    operands = [t1, t2]
    return tn.cross(torch.mul, verbose=False, tensors=operands)
def pow(t1, t2):
    """
    Element-wise power operation; see PyTorch's `pow()`.

    Unlike most binary operations in this module, this does not call :func:`tntorch.cross()`
    itself. It simply evaluates ``t1 ** t2``, so the algorithm is whatever the operands' power
    operator implements. For :class:`Tensor` operands this is presumably cross-approximation;
    confirm against :class:`Tensor`'s power operator.

    :param t1: base (a :class:`Tensor`, or any object supporting ``**``)
    :param t2: exponent (a :class:`Tensor`, or any object supporting ``**``)
    :return: the result of ``t1 ** t2`` (a :class:`Tensor` for :class:`Tensor` inputs)
    """
    # delegate to the operator so scalar bases/exponents and reflected power keep working
    return t1 ** t2
Computing file changes ...