# Source: https://github.com/rballester/tntorch
# Raw file: cross.py
# Tip revision: 3af563a42794ba169e7902198d1edd919617a958 authored by Rafael Ballester on 16 March 2023, 15:48:54 UTC
# Updated doc (ranks_cp actually must be an integer, not a list)
import tntorch as tn
import torch
import sys
import time
import numpy as np
import logging
from typing import Any, Callable, Sequence, Union


def minimum(tensors=None, function=lambda x: x, rmax=10, max_iter=10, verbose=False, **kwargs):
    """
    Estimate the smallest value of a tensor (or of a function of one or several tensors).

    The search is delegated to :func:`cross.cross()` run in minimization mode; the best value
    it sampled is reported.

    :param tensors: input :class:`Tensor` or list thereof
    :param function: function applied to the tensor(s); default is the identity
    :param rmax: used for :func:`cross.cross()`. Lower is faster; higher is more accurate (default is 10)
    :param max_iter: used for :func:`cross.cross()`. Lower is faster; higher is more accurate (default is 10)
    :param verbose: default is False
    :param **kwargs: passed to :func:`cross.cross()`

    :return: a scalar
    """

    search = cross(
        tensors=tensors,
        function=function,
        rmax=rmax,
        max_iter=max_iter,
        verbose=verbose,
        return_info=True,
        _minimize=True,
        **kwargs)
    stats = search[1]  # cross() returns (tensor, info) when return_info=True
    return stats['min']


def argmin(tensors=None, function=lambda x: x, rmax=10, max_iter=10, verbose=False, **kwargs):
    """
    Estimate where the minimum of a tensor is located (its minimizer).

    Arguments are the same as for :func:`cross.minimum()`.

    :return: a tuple
    """

    search = cross(
        tensors=tensors,
        function=function,
        rmax=rmax,
        max_iter=max_iter,
        verbose=verbose,
        return_info=True,
        _minimize=True,
        **kwargs)
    stats = search[1]  # cross() returns (tensor, info) when return_info=True
    return stats['argmin']


def maximum(tensors=None, function=lambda x: x, rmax=10, max_iter=10, verbose=False, **kwargs):
    """
    Estimate the largest value of a tensor.

    Implemented as a minimization of the negated function, whose result is negated back.
    Arguments are the same as for :func:`cross.minimum()`.

    :return: a scalar
    """

    search = cross(
        tensors=tensors,
        function=lambda *x: -function(*x),
        rmax=rmax,
        max_iter=max_iter,
        verbose=verbose,
        return_info=True,
        _minimize=True,
        **kwargs)
    stats = search[1]  # cross() returns (tensor, info) when return_info=True
    lowest_negated = stats['min']
    return -lowest_negated


def argmax(tensors=None, function=lambda x: x, rmax=10, max_iter=10, verbose=False, **kwargs):
    """
    Estimate where the maximum of a tensor is located (its maximizer).

    Implemented as the minimizer of the negated function.
    Arguments are the same as for :func:`cross.minimum()`.

    :return: a tuple
    """

    search = cross(
        tensors=tensors,
        function=lambda *x: -function(*x),
        rmax=rmax,
        max_iter=max_iter,
        verbose=verbose,
        return_info=True,
        _minimize=True,
        **kwargs)
    stats = search[1]  # cross() returns (tensor, info) when return_info=True
    return stats['argmin']


# Initialize left and right interfaces for `tensors`
def init_interfaces(tensors, rsets, N, device):
    """
    Build the initial left and right interface matrices for every input tensor.

    For each tensor, the left interfaces are a list of `N` entries of which only the first is set
    (a row of ones of width `ranks_tt[0]`); the right interfaces are a list of `N` entries where
    entry `j < N - 1` is the product of cores `j + 1 .. N - 1` sliced at the indices stored in
    `rsets[j]`, and the last entry is a column of ones.

    The matrices of ones are created with the dtype of the tensor's own cores, so that tensors
    whose dtype differs from PyTorch's default dtype can be contracted against them.

    :param tensors: list of tensors exposing `cores`, `ranks_tt` and `dim()`; each core is either a 3D TT core or a 2D CP factor
    :param rsets: list of 2D integer index arrays; column `n - 1 - j` of `rsets[j]` indexes dimension `n`
    :param N: number of dimensions
    :param device: device where the interfaces are placed

    :return: a pair `(t_linterfaces, t_rinterfaces)`, each holding one list of `N` entries per tensor
    """
    t_linterfaces = []
    t_rinterfaces = []
    for t in tensors:
        left_dtype = t.cores[0].dtype
        right_dtype = t.cores[-1].dtype
        linterfaces = [torch.ones(1, t.ranks_tt[0], dtype=left_dtype).to(device)] + [None] * (N - 1)
        rinterfaces = [None] * (N - 1) + [torch.ones(t.ranks_tt[t.dim()], 1, dtype=right_dtype).to(device)]
        for j in range(N - 1):
            indices = rsets[j]
            # One column per multi-index in rsets[j]; accumulated right-to-left
            M = torch.ones(t.cores[-1].shape[-1], len(indices), dtype=right_dtype).to(device)
            for n in range(N - 1, j, -1):
                core = t.cores[n]
                column = indices[:, n - 1 - j]
                if core.dim() == 3:  # TT core
                    M = torch.einsum('iaj,ja->ia', [core[:, column, :].to(device), M])
                else:  # CP factor
                    M = torch.einsum('ai,ia->ia', [core[column, :].to(device), M])
            rinterfaces[j] = M
        t_linterfaces.append(linterfaces)
        t_rinterfaces.append(rinterfaces)
    return t_linterfaces, t_rinterfaces


def cross(
        function: Callable = lambda x: x,
        domain=None,
        tensors: Union[Any, Sequence[Any]] = None,
        function_arg: str = 'vectors',
        ranks_tt: Union[int, Sequence[int]] = None,
        kickrank: int = 3,
        rmax: int = 100,
        eps: float = 1e-6,
        max_iter: int = 25,
        val_size: int = 1000,
        verbose: bool = True,
        return_info: bool = False,
        record_samples: bool = False,
        _minimize: bool = False,
        device: Any = None,
        suppress_warnings: bool = False,
        detach_evaluations: bool = False):

    """
    Cross-approximation routine that samples a black-box function and returns an N-dimensional tensor train approximating it. It accepts either:

    - A domain (tensor product of :math:`N` given arrays) and a function :math:`\\mathbb{R}^N \\to \\mathbb{R}`
    - A list of :math:`K` tensors of dimension :math:`N` and equal shape and a function :math:`\\mathbb{R}^K \\to \\mathbb{R}`

    :Examples:

    >>> tn.cross(function=lambda x: x**2, tensors=[t])  # Compute the element-wise square of `t` (TT-ranks are found adaptively)

    >>> domain = [torch.linspace(-1, 1, 32)]*5
    >>> tn.cross(function=lambda x, y, z, t, w: x**2 + y*z + torch.cos(t + w), domain=domain)  # Approximate a function over the rectangle :math:`[-1, 1]^5`

    >>> tn.cross(function=lambda x: torch.sum(x**2, dim=1), domain=domain, function_arg='matrix')  # An example where the function accepts a matrix

    References:

    - I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_
    - D. Savostyanov, I. Oseledets: `"Fast Adaptive Interpolation of Multi-dimensional Arrays in Tensor Train Format" (2011) <https://ieeexplore.ieee.org/document/6076873>`_
    - S. Dolgov, R. Scheichl: `"A Hybrid Alternating Least Squares - TT Cross Algorithm for Parametric PDEs" (2018) <https://arxiv.org/pdf/1707.04562.pdf>`_
    - A. Mikhalev's `maxvolpy package <https://bitbucket.org/muxas/maxvolpy>`_
    - I. Oseledets (and others)'s `ttpy package <https://github.com/oseledets/ttpy>`_

    :param function: should produce a vector of :math:`P` elements. Accepts either :math:`N` comma-separated vectors, or a matrix (see `function_arg`)
    :param domain: a list of :math:`N` vectors (incompatible with `tensors`)
    :param tensors: a :class:`Tensor` or list thereof (incompatible with `domain`)
    :param function_arg: if 'vectors', `function` accepts :math:`N` vectors of length :math:`P` each. If 'matrix', a matrix of shape :math:`P \\times N`.
    :param ranks_tt: int or list of :math:`N-1` ints. If None, will be determined adaptively
    :param kickrank: when adaptively found, ranks will be increased by this amount after every iteration (full sweep left-to-right and right-to-left)
    :param rmax: this rank will not be surpassed
    :param eps: the procedure will stop after this validation error is met (as measured after each iteration)
    :param max_iter: int, maximum number of iterations (must be at least 1)
    :param val_size: size of the validation set
    :param verbose: default is True
    :param return_info: if True, will also return a dictionary with informative metrics about the algorithm's outcome
    :param record_samples: if True, every sampled position and value is stored in the info dictionary (keys 'sample_positions' and 'sample_values')
    :param _minimize: internal flag used by :func:`minimum()` and related functions; tracks the smallest sampled value and its position in the info dictionary (keys 'min' and 'argmin')
    :param device: PyTorch device
    :param suppress_warnings: Boolean, if True, will hide the message about insufficient accuracy
    :param detach_evaluations: Boolean, if True, will remove gradient buffers for the function

    :return: an N-dimensional TT :class:`Tensor` (if `return_info`=True, also a dictionary)

    :raises ValueError: if `max_iter` < 1, if a batched tensor is given, or if `function` returns NaN/inf values
    """
    if max_iter < 1:
        # Without a single sweep there is neither an approximation nor a validation error to report
        raise ValueError('max_iter must be at least 1 (got {})'.format(max_iter))

    if device is None and tensors is not None:
        # Tuples of tensors are accepted further below, so handle them here as well
        if isinstance(tensors, (list, tuple)):
            device = tensors[0].cores[0].device
        else:
            device = tensors.cores[0].device

    if verbose:
        print('cross device is', device)

    try:
        import maxvolpy.maxvol
        maxvol = maxvolpy.maxvol.maxvol
        rect_maxvol = maxvolpy.maxvol.rect_maxvol
    except ModuleNotFoundError:
        print(
            "Functions that require cross-approximation can be accelerated with the optional maxvolpy package," +
            " which can be installed by 'pip install maxvolpy'. " +
            "More info is available at https://bitbucket.org/muxas/maxvolpy.")
        from tntorch.maxvol import py_maxvol, py_rect_maxvol
        maxvol = py_maxvol
        rect_maxvol = py_rect_maxvol

    assert domain is not None or tensors is not None
    assert function_arg in ('vectors', 'matrix')
    if function_arg == 'matrix':
        def f(*args):
            # Stack the per-tensor vectors as the columns of a single matrix
            return function(torch.cat([arg[:, None] for arg in args], dim=1))
    else:
        f = function

    if detach_evaluations:
        def build_function_wrapper(func):
            def g(*args):
                res = func(*args)
                if hasattr(res, '__len__') and not isinstance(res, torch.Tensor):
                    for i in range(len(res)):
                        if isinstance(res[i], torch.Tensor):
                            res[i] = res[i].detach()
                else:
                    if isinstance(res, torch.Tensor):
                        res = res.detach()
                return res
            return g

        f = build_function_wrapper(f)

    if tensors is None:
        tensors = tn.meshgrid(domain)

    if not hasattr(tensors, '__len__'):
        tensors = [tensors]
    for t in tensors:
        if t.batch:
            raise ValueError('Batched tensors are not supported.')
    tensors = [t.decompress_tucker_factors(_clone=False) for t in tensors]
    Is = list(tensors[0].shape)
    N = len(Is)

    # Process ranks and cap them, if needed
    if ranks_tt is None:
        ranks_tt = 1
    else:
        kickrank = None  # Ranks were fixed by the caller: never augment them
    if not hasattr(ranks_tt, '__len__'):
        ranks_tt = [ranks_tt] * (N - 1)
    ranks_tt = [1] + list(ranks_tt) + [1]
    Rs = np.array(ranks_tt)

    # Only the interior ranks are capped (same index range as in the rank augmentation step below)
    for n in list(range(1, N)) + list(range(N - 1, 0, -1)):
        Rs[n] = min(Rs[n - 1] * Is[n - 1], Rs[n], Is[n] * Rs[n + 1])

    # Initialize cores at random
    cores = [torch.randn(Rs[n], Is[n], Rs[n + 1]).to(device) for n in range(N)]

    # Prepare left and right sets
    lsets = [np.array([[0]])] + [None] * (N - 1)
    randint = np.hstack([np.random.randint(0, Is[n + 1], [max(Rs), 1]) for n in range(N - 1)] + [np.zeros([max(Rs), 1], dtype=int)])
    rsets = [randint[:Rs[n + 1], n:] for n in range(N - 1)] + [np.array([[0]])]

    t_linterfaces, t_rinterfaces = init_interfaces(tensors, rsets, N, device)

    # Create a validation set
    Xs_val = [torch.as_tensor(np.random.choice(I, int(val_size))).to(device) for I in Is]
    ys_val = f(*[t[Xs_val].torch() for t in tensors])
    if ys_val.dim() > 1:
        assert ys_val.dim() == 2
        assert ys_val.shape[1] == 1
        ys_val = ys_val[:, 0]

    assert len(ys_val) == val_size
    norm_ys_val = torch.norm(ys_val)

    if verbose:
        print('Cross-approximation over a {}D domain containing {:g} grid points:'.format(N, tensors[0].numel()))
    start = time.time()
    converged = False

    info = {
        'nsamples': 0,
        'eval_time': 0,
        'val_epss': [],
        'min': 0,
        'argmin': None
    }
    if record_samples:
        # One column per function argument, i.e. per input tensor
        info['sample_positions'] = torch.zeros(0, len(tensors)).to(device)
        info['sample_values'] = torch.zeros(0).to(device)

    def evaluate_function(j):  # Evaluate function over Rs[j] x Rs[j+1] fibers, each of size I[j]
        Xs = []
        for k, t in enumerate(tensors):
            if t.cores[j].dim() == 3:  # TT core
                V = torch.einsum('ai,ibj,jc->abc', [t_linterfaces[k][j], t.cores[j], t_rinterfaces[k][j]])
            else:  # CP factor
                V = torch.einsum('ai,bi,ic->abc', [t_linterfaces[k][j], t.cores[j], t_rinterfaces[k][j]])
            Xs.append(V.flatten())

        eval_start = time.time()
        evaluation = f(*Xs)
        if record_samples:
            info['sample_positions'] = torch.cat((info['sample_positions'], torch.cat([x[:, None] for x in Xs], dim=1)), dim=0)
            info['sample_values'] = torch.cat((info['sample_values'], evaluation))
        info['eval_time'] += time.time() - eval_start
        raw_evaluation = evaluation  # Kept for error reporting; `evaluation` may be transformed below
        if _minimize:
            evaluation = np.pi / 2 - torch.atan((evaluation - info['min']))  # Function used by I. Oseledets for TT minimization in ttpy
            evaluation_argmax = torch.argmax(evaluation)
            eval_min = torch.tan(np.pi / 2 - evaluation[evaluation_argmax]) + info['min']
            # 'argmin' is None until a first candidate is stored. Testing `info['min'] == 0` instead
            # would let a genuine minimum of exactly 0 be overwritten by later, larger values.
            if info['argmin'] is None or eval_min < info['min']:
                coords = np.unravel_index(evaluation_argmax.cpu(), [Rs[j], Is[j], Rs[j + 1]])
                info['min'] = eval_min
                info['argmin'] = tuple(lsets[j][coords[0]][1:]) + tuple([coords[1]]) + tuple(rsets[j][coords[2]][:-1])

        # Check for nan/inf values
        if evaluation.dim() == 2:
            evaluation = evaluation[:, 0]
        invalid = torch.nonzero(torch.isnan(evaluation) | torch.isinf(evaluation))
        if len(invalid) > 0:
            invalid = invalid[0].item()
            # Report the value that was already computed instead of calling the function again
            raise ValueError(
                'Invalid return value for function {}: f({}) = {}'.format(
                    function,
                    ', '.join('{:g}'.format(x[invalid].item()) for x in Xs),
                    raw_evaluation.reshape(-1)[invalid].item()))

        V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]])
        info['nsamples'] += V.numel()
        return V

    left_locals = []
    val_eps = None

    # Sweeps
    for i in range(max_iter):

        if verbose:
            print('iter: {: <{}}'.format(i, len('{}'.format(max_iter)) + 1), end='')
            sys.stdout.flush()

        left_locals = []

        # Left-to-right
        for j in range(N - 1):

            # Update tensors for current indices
            V = evaluate_function(j)

            # QR + maxvol towards the right
            V = torch.reshape(V, [-1, Rs[j + 1]])  # Left unfolding
            Q, _ = torch.linalg.qr(V)
            if _minimize:
                local, _ = rect_maxvol(Q.detach().cpu().numpy(), maxK=Q.shape[1])
            else:
                local, _ = maxvol(Q.detach().cpu().numpy())
            V = torch.linalg.lstsq(Q[local, :].t(), Q.t()).solution.t()
            cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j + 1]])
            left_locals.append(local)

            # Map local indices to global ones
            local_r, local_i = np.unravel_index(local, [Rs[j], Is[j]])
            lsets[j + 1] = np.c_[lsets[j][local_r, :], local_i]
            for k, t in enumerate(tensors):
                if t.cores[j].dim() == 3:  # TT core
                    t_linterfaces[k][j + 1] = torch.einsum('ai,iaj->aj', [t_linterfaces[k][j][local_r, :], t.cores[j][:, local_i, :]])
                else:  # CP factor
                    t_linterfaces[k][j + 1] = torch.einsum('ai,ai->ai', [t_linterfaces[k][j][local_r, :], t.cores[j][local_i, :]])

        # Right-to-left sweep
        for j in range(N - 1, 0, -1):

            # Update tensors for current indices
            V = evaluate_function(j)

            # QR + maxvol towards the left
            V = torch.reshape(V, [Rs[j], -1])  # Right unfolding
            Q, _ = torch.linalg.qr(V.t())
            if _minimize:
                local, _ = rect_maxvol(Q.detach().cpu().numpy(), maxK=Q.shape[1])
            else:
                local, _ = maxvol(Q.detach().cpu().numpy())
            V = torch.linalg.lstsq(Q[local, :].t(), Q.t()).solution
            cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j + 1]])

            # Map local indices to global ones
            local_i, local_r = np.unravel_index(local, [Is[j], Rs[j + 1]])
            rsets[j - 1] = np.c_[local_i, rsets[j][local_r, :]]
            for k, t in enumerate(tensors):
                if t.cores[j].dim() == 3:  # TT core
                    t_rinterfaces[k][j - 1] = torch.einsum('iaj,ja->ia', [t.cores[j][:, local_i, :], t_rinterfaces[k][j][:, local_r]])
                else:  # CP factor
                    t_rinterfaces[k][j - 1] = torch.einsum('ai,ia->ia', [t.cores[j][local_i, :], t_rinterfaces[k][j][:, local_r]])

        # Leave the first core ready
        V = evaluate_function(0)
        cores[0] = V

        # Evaluate validation error
        val_eps = torch.norm(ys_val - tn.Tensor(cores)[Xs_val].torch()) / norm_ys_val
        info['val_epss'].append(val_eps)
        if val_eps < eps:
            converged = True

        if verbose:  # Print status
            if _minimize:
                print('| best: {:.8g}'.format(info['min']), end='')
            else:
                print('| eps: {:.3e}'.format(val_eps), end='')
            print(' | time: {:8.4f} | largest rank: {:3d}'.format(time.time() - start, max(Rs)), end='')
            if converged:
                print(' <- converged: eps < {}'.format(eps))
            elif i == max_iter - 1:
                print(' <- max_iter was reached: {}'.format(max_iter))
            else:
                print()
        if converged:
            break
        elif i < max_iter - 1 and kickrank is not None:  # Augment ranks
            newRs = Rs.copy()
            newRs[1:-1] = np.minimum(rmax, newRs[1:-1] + kickrank)
            for n in list(range(1, N)) + list(range(N - 1, 0, -1)):
                newRs[n] = min(newRs[n - 1] * Is[n - 1], newRs[n], Is[n] * newRs[n + 1])
            extra = np.hstack([np.random.randint(0, Is[n + 1], [max(newRs), 1]) for n in range(N - 1)] + [np.zeros([max(newRs), 1], dtype=int)])
            for n in range(N - 1):
                if newRs[n + 1] > Rs[n + 1]:
                    # Pad the right set with as many random multi-indices as ranks were added
                    rsets[n] = np.vstack([rsets[n], extra[:newRs[n + 1] - Rs[n + 1], n:]])
            Rs = newRs
            t_linterfaces, t_rinterfaces = init_interfaces(tensors, rsets, N, device)  # Recompute interfaces

    if val_eps > eps and not _minimize and not suppress_warnings:
        logging.warning('eps={:g} (larger than {}) when cross-approximating {}'.format(val_eps, eps, function))

    if verbose:
        # The accumulated evaluation time can be exactly 0 with a coarse clock; avoid dividing by it
        if info['eval_time'] > 0:
            evals_per_second = info['nsamples'] / info['eval_time']
        else:
            evals_per_second = float('inf')
        print('Did {} function evaluations, which took {:.4g}s ({:.4g} evals/s)'.format(
            info['nsamples'], info['eval_time'], evals_per_second))
        print()

    ret = tn.Tensor([c if isinstance(c, torch.Tensor) else torch.tensor(c) for c in cores])
    if return_info:
        info['lsets'] = lsets
        info['rsets'] = rsets
        info['Rs'] = Rs
        info['left_locals'] = left_locals
        info['total_time'] = time.time() - start
        info['val_eps'] = val_eps
        return ret, info
    else:
        return ret


def cross_forward(
        info,
        function=lambda x: x,
        domain=None,
        tensors=None,
        function_arg='vectors',
        return_info=False):
    """
    Given TT-cross indices and a black-box function (to be evaluated on an arbitrary grid), computes a differentiable TT tensor as given by the TT-cross interpolation formula.

    Reference: I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_

    :param info: dictionary with the indices returned by `tntorch.cross()` (keys 'lsets', 'rsets', 'left_locals' and 'Rs' are read). Its 'lsets' entries are recomputed in place, and keys 'Xs' and 'shapes' are added if `return_info` is True
    :param function: a function :math:`\\mathbb{R}^M \\to \\mathbb{R}`, as in `tntorch.cross()`
    :param domain: domain where `function` will be evaluated on, as in `tntorch.cross()`
    :param tensors: list of :math:`M` TT tensors where `function` will be evaluated on
    :param function_arg: type of argument accepted by `function`. See `tntorch.cross()`
    :param return_info: Boolean, if True, will also return a dictionary with informative metrics about the algorithm's outcome

    :return: a TT :class:`Tensor` (if `return_info`=True, also a dictionary)
    """

    assert domain is not None or tensors is not None
    assert function_arg in ('vectors', 'matrix')
    if function_arg == 'matrix':
        def f(*args):
            # Stack the per-tensor vectors as the columns of a single matrix
            return function(torch.cat([arg[:, None] for arg in args], dim=1))
    else:
        f = function

    device = None
    if tensors is None:
        tensors = tn.meshgrid(domain)
        device = domain[0].device
    if not hasattr(tensors, '__len__'):
        tensors = [tensors]
    if device is None:
        # Same rule as in `cross()`: place the interfaces on the device of the input tensors
        device = tensors[0].cores[0].device

    Is = list(tensors[0].shape)
    N = len(Is)

    # Load index information from dictionary
    lsets = info['lsets']
    rsets = info['rsets']
    left_locals = info['left_locals']
    Rs = info['Rs']

    if return_info:
        # One column per function argument, i.e. per input tensor
        info['Xs'] = torch.zeros(0, len(tensors))
        info['shapes'] = []

    t_linterfaces, t_rinterfaces = init_interfaces(tensors, rsets, N, device)

    def evaluate_function(j):  # Evaluate function over Rs[j] x Rs[j+1] fibers, each of size I[j]
        Xs = []
        for k, t in enumerate(tensors):
            if t.cores[j].dim() == 3:  # TT core
                V = torch.einsum('ai,ibj,jc->abc', [t_linterfaces[k][j], t.cores[j], t_rinterfaces[k][j]])
            else:  # CP factor
                V = torch.einsum('ai,bi,ic->abc', [t_linterfaces[k][j], t.cores[j], t_rinterfaces[k][j]])
            Xs.append(V.flatten())

        evaluation = f(*Xs)

        if return_info:
            info['Xs'] = torch.cat((info['Xs'], torch.cat([x[:, None] for x in Xs], dim=1).detach().cpu()), dim=0)
            info['shapes'].append([Rs[j], Is[j], Rs[j + 1]])

        V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]])
        return V

    cores = []

    # Cross-interpolation formula, left-to-right
    for j in range(0, N - 1):

        # Update tensors for current indices
        V = evaluate_function(j)
        V = torch.reshape(V, [-1, V.shape[2]])  # Left unfolding
        A = V[left_locals[j], :]  # Rows selected by the stored TT-cross indices
        X = torch.linalg.lstsq(A.t(), V.t()).solution.t()

        cores.append(X.reshape(Rs[j], Is[j], Rs[j + 1]))

        # Map local indices to global ones
        local_r, local_i = np.unravel_index(left_locals[j], [Rs[j], Is[j]])
        lsets[j + 1] = np.c_[lsets[j][local_r, :], local_i]
        for k, t in enumerate(tensors):
            if t.cores[j].dim() == 3:  # TT core
                t_linterfaces[k][j + 1] = torch.einsum('ai,iaj->aj', [t_linterfaces[k][j][local_r, :], t.cores[j][:, local_i, :]])
            else:  # CP factor
                t_linterfaces[k][j + 1] = torch.einsum('ai,ai->ai', [t_linterfaces[k][j][local_r, :], t.cores[j][local_i, :]])

    # Leave the last core ready
    X = evaluate_function(N - 1)
    cores.append(X)

    if return_info:
        return tn.Tensor(cores), info
    else:
        return tn.Tensor(cores)
# back to top