Raw File
import tntorch as tn
import torch


def cumsum(t, dim=None):
    """
    Cumulative sum of a tensor along one or more dimensions, in the spirit of PyTorch's `cumsum()`.

    The input is cloned first, and only the clone is modified. For each requested
    dimension, the running sum is applied to that dimension's factor matrix
    (`Us[n]`) when one is present, and to the core (`cores[n]`) otherwise.

    :param t: input :class:`Tensor`
    :param dim: an int or list of ints (default: all)

    :return: a :class:`Tensor` of the same shape
    """

    # Normalize `dim` into an iterable of dimension indices
    if dim is None:
        axes = range(t.dim())
    elif hasattr(dim, '__len__'):
        axes = dim
    else:
        axes = [dim]

    result = t.clone()
    for axis in axes:
        factor = result.Us[axis]
        if factor is None:
            # No factor matrix for this dimension: accumulate along the
            # second-to-last axis of the core
            result.cores[axis] = torch.cumsum(result.cores[axis], dim=-2)
        else:
            # Batched tensors accumulate along axis 1 of the factor, others along axis 0
            result.Us[axis] = torch.cumsum(factor, dim=1 if result.batch else 0)
    return result


def cumprod(t, dim=None):
    """
    Cumulative product of a tensor along one or more dimensions, in the spirit of PyTorch's `cumprod()`.

    Evaluated as `exp(cumsum(log(t)))`, i.e. a cumulative sum carried out in the log domain.

    Note: this function is approximate and uses cross-approximation (:func:`tntorch.cross()`)

    NOTE(review): because the input goes through a logarithm, this presumably only
    makes sense for tensors with strictly positive entries -- confirm.

    :param t: input :class:`Tensor`
    :param dim: an int or list of ints (default: all)

    :return: a :class:`Tensor` of the same shape
    """

    log_t = tn.log(t)
    running_log_sum = tn.cumsum(log_t, dim=dim)
    return tn.exp(running_log_sum)


"""
Unary operations (using cross-approximation)
"""


def abs(t):
    """
    Approximates the element-wise absolute value of a tensor; mirrors PyTorch's `abs()`.

    `torch.abs` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.abs, tensors=t, verbose=False)


def acos(t):
    """
    Approximates the element-wise arccosine of a tensor; mirrors PyTorch's `acos()`.

    `torch.acos` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.acos, tensors=t, verbose=False)


def asin(t):
    """
    Approximates the element-wise arcsine of a tensor; mirrors PyTorch's `asin()`.

    `torch.asin` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.asin, tensors=t, verbose=False)


def cos(t):
    """
    Approximates the element-wise cosine of a tensor; mirrors PyTorch's `cos()`.

    `torch.cos` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.cos, tensors=t, verbose=False)


def cosh(t):
    """
    Approximates the element-wise hyperbolic cosine of a tensor; mirrors PyTorch's `cosh()`.

    `torch.cosh` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.cosh, tensors=t, verbose=False)


def erf(t):
    """
    Approximates the element-wise error function of a tensor; mirrors PyTorch's `erf()`.

    `torch.erf` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.erf, tensors=t, verbose=False)


def erfinv(t):
    """
    Approximates the element-wise inverse error function of a tensor; mirrors PyTorch's `erfinv()`.

    `torch.erfinv` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.erfinv, tensors=t, verbose=False)


def exp(t):
    """
    Approximates the element-wise exponential of a tensor; mirrors PyTorch's `exp()`.

    `torch.exp` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.exp, tensors=t, verbose=False)


def log(t):
    """
    Approximates the element-wise natural logarithm of a tensor; mirrors PyTorch's `log()`.

    `torch.log` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.log, tensors=t, verbose=False)


def log10(t):
    """
    Approximates the element-wise base-10 logarithm of a tensor; mirrors PyTorch's `log10()`.

    `torch.log10` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.log10, tensors=t, verbose=False)


def log2(t):
    """
    Approximates the element-wise base-2 logarithm of a tensor; mirrors PyTorch's `log2()`.

    `torch.log2` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.log2, tensors=t, verbose=False)


def reciprocal(t):
    """
    Approximates the element-wise reciprocal of a tensor; mirrors PyTorch's `reciprocal()`.

    `torch.reciprocal` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.reciprocal, tensors=t, verbose=False)


def rsqrt(t):
    """
    Approximates the element-wise reciprocal square root of a tensor; mirrors PyTorch's `rsqrt()`.

    `torch.rsqrt` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.rsqrt, tensors=t, verbose=False)


def sigmoid(t):
    """
    Approximates the element-wise sigmoid of a tensor; mirrors PyTorch's `sigmoid()`.

    `torch.sigmoid` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.sigmoid, tensors=t, verbose=False)


def sin(t):
    """
    Approximates the element-wise sine of a tensor; mirrors PyTorch's `sin()`.

    `torch.sin` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.sin, tensors=t, verbose=False)


def sinh(t):
    """
    Approximates the element-wise hyperbolic sine of a tensor; mirrors PyTorch's `sinh()`.

    `torch.sinh` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.sinh, tensors=t, verbose=False)


def sqrt(t):
    """
    Approximates the element-wise square root of a tensor; mirrors PyTorch's `sqrt()`.

    `torch.sqrt` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.sqrt, tensors=t, verbose=False)


def tan(t):
    """
    Approximates the element-wise tangent of a tensor; mirrors PyTorch's `tan()`.

    `torch.tan` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.tan, tensors=t, verbose=False)


def tanh(t):
    """
    Approximates the element-wise hyperbolic tangent of a tensor; mirrors PyTorch's `tanh()`.

    `torch.tanh` is handed to :func:`tntorch.cross()` as the element-wise function,
    so the result is a cross-approximation rather than an exact evaluation.

    :param t: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    return tn.cross(torch.tanh, tensors=t, verbose=False)


"""
Binary operations (using cross-approximation)
"""


def add(t1, t2):
    """
    Approximates the element-wise sum of two tensors; mirrors PyTorch's `add()`.

    `torch.add` is handed to :func:`tntorch.cross()` as the element-wise function,
    with `t1` and `t2` (in that order) as its two inputs, so the result is a
    cross-approximation rather than an exact evaluation.

    :param t1: input :class:`Tensor`
    :param t2: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    operands = [t1, t2]
    return tn.cross(torch.add, tensors=operands, verbose=False)


def atan2(t1, t2):
    """
    Approximates the element-wise two-argument arctangent of two tensors; mirrors PyTorch's `atan2()`.

    `torch.atan2` is handed to :func:`tntorch.cross()` as the element-wise function,
    with `t1` supplying its first argument and `t2` its second, so the result is a
    cross-approximation rather than an exact evaluation.

    :param t1: input :class:`Tensor`
    :param t2: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    operands = [t1, t2]
    return tn.cross(torch.atan2, tensors=operands, verbose=False)


def div(t1, t2):
    """
    Element-wise division of two tensors; mirrors PyTorch's `div()`.

    This is a thin wrapper around the `/` operator, so the actual algorithm is
    whatever the operands implement for it (presumably cross-approximation for
    :class:`Tensor` -- that implementation is not visible here).

    :param t1: input :class:`Tensor`
    :param t2: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    quotient = t1 / t2
    return quotient


def mul(t1, t2):
    """
    Approximates the element-wise product of two tensors; mirrors PyTorch's `mul()`.

    `torch.mul` is handed to :func:`tntorch.cross()` as the element-wise function,
    with `t1` and `t2` (in that order) as its two inputs, so the result is a
    cross-approximation rather than an exact evaluation.

    :param t1: input :class:`Tensor`
    :param t2: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    operands = [t1, t2]
    return tn.cross(torch.mul, tensors=operands, verbose=False)


def pow(t1, t2):
    """
    Element-wise power of two tensors (`t1` raised to `t2`); mirrors PyTorch's `pow()`.

    This is a thin wrapper around the `**` operator, so the actual algorithm is
    whatever the operands implement for it (presumably cross-approximation for
    :class:`Tensor` -- that implementation is not visible here).

    :param t1: input :class:`Tensor`
    :param t2: input :class:`Tensor`

    :return: a :class:`Tensor`
    """

    # The built-in pow() is shadowed by this very function, so the operator form is used
    powered = t1 ** t2
    return powered
back to top