# cross.py
import tntorch as tn
import torch
import sys
import time
import numpy as np
import logging
from typing import Any, Callable, Optional, Sequence, Union
def minimum(tensors=None, function=lambda x: x, rmax=10, max_iter=10, verbose=False, **kwargs):
"""
    Estimate the minimal element of a tensor (or of a function of one or several tensors).

    :param tensors: input :class:`Tensor` (or list thereof)
    :param function: function to apply to the tensor(s) element-wise before minimizing (default is the identity)
:param rmax: used for :func:`cross.cross()`. Lower is faster; higher is more accurate (default is 10)
:param max_iter: used for :func:`cross.cross()`. Lower is faster; higher is more accurate (default is 10)
:param verbose: default is False
:param **kwargs: passed to :func:`cross.cross()`
:return: a scalar
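
    :Example:

    A minimal usage sketch; the random input tensor is illustrative.

    >>> t = tn.rand([32]*4, ranks_tt=5)
    >>> tn.minimum(tensors=[t])  # should approximate torch.min(t.torch())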
"""
_, info = cross(
**kwargs,
tensors=tensors,
function=function,
rmax=rmax,
max_iter=max_iter,
verbose=verbose,
return_info=True,
_minimize=True)
return info['min']
def argmin(tensors=None, function=lambda x: x, rmax=10, max_iter=10, verbose=False, **kwargs):
"""
Estimate the minimizer of a tensor (position where its minimum is located).
For arguments, see :func:`cross.minimum()`
:return: a tuple
"""
_, info = cross(
**kwargs,
tensors=tensors,
function=function,
rmax=rmax,
max_iter=max_iter,
verbose=verbose,
return_info=True,
_minimize=True)
return info['argmin']
def maximum(tensors=None, function=lambda x: x, rmax=10, max_iter=10, verbose=False, **kwargs):
"""
    Estimate the maximal element of a tensor (or of a function of one or several tensors).
For arguments, see :func:`cross.minimum()`
:return: a scalar
"""
_, info = cross(
**kwargs,
function=lambda *x: -function(*x),
tensors=tensors,
rmax=rmax,
max_iter=max_iter,
verbose=verbose,
return_info=True,
_minimize=True)
    # `function` was negated above, so negating the stored minimum recovers the maximum of the original function
    return -info['min']
def argmax(tensors=None, function=lambda x: x, rmax=10, max_iter=10, verbose=False, **kwargs):
"""
Estimate the maximizer of a tensor (position where its maximum is located).
For arguments, see :func:`cross.minimum()`
:return: a tuple
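
    :Example:

    A minimal usage sketch; the random input tensor is illustrative.

    >>> t = tn.rand([32]*4, ranks_tt=5)
    >>> pos = tn.argmax(tensors=[t])  # a tuple of N indices
    >>> t[pos]  # should be close to tn.maximum(tensors=[t])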
"""
_, info = cross(
**kwargs,
tensors=tensors,
function=lambda *x: -function(*x),
rmax=rmax,
max_iter=max_iter,
verbose=verbose,
return_info=True,
_minimize=True)
    # Since `function` was negated, the minimizer found by `cross` is the maximizer of the original function
    return info['argmin']
# Initialize left and right interfaces for `tensors`
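# For each input tensor, the left interface at position j is the contraction of
# cores 0..j-1 restricted to the index rows in lsets[j]; the right interface is
# the contraction of cores j+1..N-1 restricted to rsets[j]. Together they let
# evaluate_function() read off whole fibers of the tensor at once.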
def init_interfaces(tensors, rsets, N, device):
t_linterfaces = []
t_rinterfaces = []
for t in tensors:
linterfaces = [torch.ones(1, t.ranks_tt[0]).to(device)] + [None] * (N - 1)
rinterfaces = [None] * (N - 1) + [torch.ones(t.ranks_tt[t.dim()], 1).to(device)]
for j in range(N - 1):
M = torch.ones(t.cores[-1].shape[-1], len(rsets[j])).to(device)
for n in range(N - 1, j, -1):
if t.cores[n].dim() == 3: # TT core
M = torch.einsum('iaj,ja->ia', [t.cores[n][:, rsets[j][:, n - 1 - j], :].to(device), M])
else: # CP factor
M = torch.einsum('ai,ia->ia', [t.cores[n][rsets[j][:, n - 1 - j], :].to(device), M])
rinterfaces[j] = M
t_linterfaces.append(linterfaces)
t_rinterfaces.append(rinterfaces)
return t_linterfaces, t_rinterfaces
def cross(
function: Callable = lambda x: x,
domain=None,
        tensors: Optional[Union[Any, Sequence[Any]]] = None,
        function_arg: str = 'vectors',
        ranks_tt: Optional[Union[int, Sequence[int]]] = None,
kickrank: int = 3,
rmax: int = 100,
eps: float = 1e-6,
max_iter: int = 25,
val_size: int = 1000,
verbose: bool = True,
return_info: bool = False,
record_samples: bool = False,
_minimize: bool = False,
device: Any = None,
suppress_warnings: bool = False,
detach_evaluations: bool = False):
"""
Cross-approximation routine that samples a black-box function and returns an N-dimensional tensor train approximating it. It accepts either:
- A domain (tensor product of :math:`N` given arrays) and a function :math:`\\mathbb{R}^N \\to \\mathbb{R}`
- A list of :math:`K` tensors of dimension :math:`N` and equal shape and a function :math:`\\mathbb{R}^K \\to \\mathbb{R}`
:Examples:
    >>> tn.cross(function=lambda x: x**2, tensors=[t]) # Compute the element-wise square of `t` using cross-approximation
>>> domain = [torch.linspace(-1, 1, 32)]*5
>>> tn.cross(function=lambda x, y, z, t, w: x**2 + y*z + torch.cos(t + w), domain=domain) # Approximate a function over the rectangle :math:`[-1, 1]^5`
>>> tn.cross(function=lambda x: torch.sum(x**2, dim=1), domain=domain, function_arg='matrix') # An example where the function accepts a matrix
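
    The following sketch also retrieves convergence metrics via `return_info` (the Gaussian integrand is illustrative):

    >>> t, info = tn.cross(function=lambda x: torch.exp(-torch.sum(x**2, dim=1)), domain=domain, function_arg='matrix', return_info=True)
    >>> info['val_eps']  # relative error attained on the random validation set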
References:
- I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_
- D. Savostyanov, I. Oseledets: `"Fast Adaptive Interpolation of Multi-dimensional Arrays in Tensor Train Format" (2011) <https://ieeexplore.ieee.org/document/6076873>`_
- S. Dolgov, R. Scheichl: `"A Hybrid Alternating Least Squares - TT Cross Algorithm for Parametric PDEs" (2018) <https://arxiv.org/pdf/1707.04562.pdf>`_
- A. Mikhalev's `maxvolpy package <https://bitbucket.org/muxas/maxvolpy>`_
- I. Oseledets (and others)'s `ttpy package <https://github.com/oseledets/ttpy>`_
:param function: should produce a vector of :math:`P` elements. Accepts either :math:`N` comma-separated vectors, or a matrix (see `function_arg`)
:param domain: a list of :math:`N` vectors (incompatible with `tensors`)
:param tensors: a :class:`Tensor` or list thereof (incompatible with `domain`)
:param function_arg: if 'vectors', `function` accepts :math:`N` vectors of length :math:`P` each. If 'matrix', a matrix of shape :math:`P \\times N`.
:param ranks_tt: int or list of :math:`N-1` ints. If None, will be determined adaptively
:param kickrank: when adaptively found, ranks will be increased by this amount after every iteration (full sweep left-to-right and right-to-left)
:param rmax: this rank will not be surpassed
    :param eps: the procedure will stop once the relative validation error falls below this threshold (measured after each iteration)
    :param max_iter: maximum number of iterations (sweeps); default is 25
    :param val_size: size of the validation set
    :param verbose: default is True
    :param return_info: if True, will also return a dictionary with informative metrics about the algorithm's outcome
    :param record_samples: Boolean, if True, all sampled positions and values will be stored in the `info` dictionary
:param device: PyTorch device
:param suppress_warnings: Boolean, if True, will hide the message about insufficient accuracy
:param detach_evaluations: Boolean, if True, will remove gradient buffers for the function
:return: an N-dimensional TT :class:`Tensor` (if `return_info`=True, also a dictionary)
"""
if device is None and tensors is not None:
        if isinstance(tensors, list):
device = tensors[0].cores[0].device
else:
device = tensors.cores[0].device
if verbose:
print('cross device is', device)
try:
import maxvolpy.maxvol
maxvol = maxvolpy.maxvol.maxvol
rect_maxvol = maxvolpy.maxvol.rect_maxvol
except ModuleNotFoundError:
print(
"Functions that require cross-approximation can be accelerated with the optional maxvolpy package," +
" which can be installed by 'pip install maxvolpy'. " +
"More info is available at https://bitbucket.org/muxas/maxvolpy.")
from tntorch.maxvol import py_maxvol, py_rect_maxvol
maxvol = py_maxvol
rect_maxvol = py_rect_maxvol
assert domain is not None or tensors is not None
assert function_arg in ('vectors', 'matrix')
if function_arg == 'matrix':
def f(*args):
return function(torch.cat([arg[:, None] for arg in args], dim=1))
else:
f = function
if detach_evaluations:
def build_function_wrapper(func):
def g(*args):
res = func(*args)
if hasattr(res, '__len__') and not isinstance(res, torch.Tensor):
for i in range(len(res)):
if isinstance(res[i], torch.Tensor):
res[i] = res[i].detach()
else:
if isinstance(res, torch.Tensor):
res = res.detach()
return res
return g
f = build_function_wrapper(f)
if tensors is None:
tensors = tn.meshgrid(domain)
if not hasattr(tensors, '__len__'):
tensors = [tensors]
for t in tensors:
if t.batch:
raise ValueError('Batched tensors are not supported.')
tensors = [t.decompress_tucker_factors(_clone=False) for t in tensors]
Is = list(tensors[0].shape)
N = len(Is)
# Process ranks and cap them, if needed
if ranks_tt is None:
ranks_tt = 1
else:
kickrank = None
if not hasattr(ranks_tt, '__len__'):
ranks_tt = [ranks_tt] * (N - 1)
ranks_tt = [1] + list(ranks_tt) + [1]
Rs = np.array(ranks_tt)
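    # A TT-rank can never exceed the size of the corresponding unfolding, so cap each rank by its neighbors (left-to-right, then right-to-left)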
for n in list(range(1, N)) + list(range(N - 1, -1, -1)):
Rs[n] = min(Rs[n - 1] * Is[n - 1], Rs[n], Is[n] * Rs[n + 1])
# Initialize cores at random
cores = [torch.randn(Rs[n], Is[n], Rs[n + 1]).to(device) for n in range(N)]
# Prepare left and right sets
lsets = [np.array([[0]])] + [None] * (N - 1)
randint = np.hstack([np.random.randint(0, Is[n + 1], [max(Rs), 1]) for n in range(N - 1)] + [np.zeros([max(Rs), 1], dtype=int)])
rsets = [randint[:Rs[n + 1], n:] for n in range(N - 1)] + [np.array([[0]])]
t_linterfaces, t_rinterfaces = init_interfaces(tensors, rsets, N, device)
# Create a validation set
Xs_val = [torch.as_tensor(np.random.choice(I, int(val_size))).to(device) for I in Is]
ys_val = f(*[t[Xs_val].torch() for t in tensors])
if ys_val.dim() > 1:
assert ys_val.dim() == 2
assert ys_val.shape[1] == 1
ys_val = ys_val[:, 0]
assert len(ys_val) == val_size
norm_ys_val = torch.norm(ys_val)
if verbose:
print('Cross-approximation over a {}D domain containing {:g} grid points:'.format(N, tensors[0].numel()))
start = time.time()
converged = False
info = {
'nsamples': 0,
'eval_time': 0,
'val_epss': [],
'min': 0,
'argmin': None
}
if record_samples:
info['sample_positions'] = torch.zeros(0, N).to(device)
info['sample_values'] = torch.zeros(0).to(device)
def evaluate_function(j): # Evaluate function over Rs[j] x Rs[j+1] fibers, each of size I[j]
Xs = []
for k, t in enumerate(tensors):
if tensors[k].cores[j].dim() == 3: # TT core
V = torch.einsum('ai,ibj,jc->abc', [t_linterfaces[k][j], t.cores[j], t_rinterfaces[k][j]])
else: # CP factor
V = torch.einsum('ai,bi,ic->abc', [t_linterfaces[k][j], t.cores[j], t_rinterfaces[k][j]])
Xs.append(V.flatten())
eval_start = time.time()
evaluation = f(*Xs)
if record_samples:
info['sample_positions'] = torch.cat((info['sample_positions'], torch.cat([x[:, None] for x in Xs], dim=1)), dim=0)
info['sample_values'] = torch.cat((info['sample_values'], evaluation))
info['eval_time'] += time.time() - eval_start
if _minimize:
            evaluation = np.pi / 2 - torch.atan(evaluation - info['min'])  # Transform used by I. Oseledets for TT minimization in ttpy
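            # atan is increasing, so the transform above is decreasing: the argmax below locates the current best candidate minimum, and the tan expression inverts the transform back to the original scale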
evaluation_argmax = torch.argmax(evaluation)
eval_min = torch.tan(np.pi / 2 - evaluation[evaluation_argmax]) + info['min']
if info['min'] == 0 or eval_min < info['min']:
coords = np.unravel_index(evaluation_argmax.cpu(), [Rs[j], Is[j], Rs[j + 1]])
info['min'] = eval_min
info['argmin'] = tuple(lsets[j][coords[0]][1:]) + tuple([coords[1]]) + tuple(rsets[j][coords[2]][:-1])
        if evaluation.dim() == 2:
            evaluation = evaluation[:, 0]
        # Check for nan/inf values
invalid = torch.nonzero(torch.isnan(evaluation) | torch.isinf(evaluation))
if len(invalid) > 0:
invalid = invalid[0].item()
raise ValueError(
'Invalid return value for function {}: f({}) = {}'.format(
function,
', '.join('{:g}'.format(x[invalid].detach().cpu().numpy()) for x in Xs),
f(*[x[invalid:invalid + 1][:, None] for x in Xs]).item()))
V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]])
info['nsamples'] += V.numel()
return V
# Sweeps
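    # Each iteration below performs one full left-to-right and one right-to-left sweep, updating one core at a time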
for i in range(max_iter):
if verbose:
print('iter: {: <{}}'.format(i, len('{}'.format(max_iter)) + 1), end='')
sys.stdout.flush()
left_locals = []
# Left-to-right
for j in range(N - 1):
# Update tensors for current indices
V = evaluate_function(j)
# QR + maxvol towards the right
V = torch.reshape(V, [-1, Rs[j + 1]]) # Left unfolding
Q, _ = torch.linalg.qr(V)
if _minimize:
local, _ = rect_maxvol(Q.detach().cpu().numpy(), maxK=Q.shape[1])
else:
local, _ = maxvol(Q.detach().cpu().numpy())
V = torch.linalg.lstsq(Q[local, :].t(), Q.t()).solution.t()
cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j + 1]])
left_locals.append(local)
# Map local indices to global ones
local_r, local_i = np.unravel_index(local, [Rs[j], Is[j]])
lsets[j+1] = np.c_[lsets[j][local_r, :], local_i]
for k, t in enumerate(tensors):
if t.cores[j].dim() == 3: # TT core
t_linterfaces[k][j+1] = torch.einsum('ai,iaj->aj', [t_linterfaces[k][j][local_r, :], t.cores[j][:, local_i, :]])
else: # CP factor
t_linterfaces[k][j+1] = torch.einsum('ai,ai->ai', [t_linterfaces[k][j][local_r, :], t.cores[j][local_i, :]])
# Right-to-left sweep
for j in range(N - 1, 0, -1):
# Update tensors for current indices
V = evaluate_function(j)
# QR + maxvol towards the left
V = torch.reshape(V, [Rs[j], -1]) # Right unfolding
Q, _ = torch.linalg.qr(V.t())
if _minimize:
local, _ = rect_maxvol(Q.detach().cpu().numpy(), maxK=Q.shape[1])
else:
local, _ = maxvol(Q.detach().cpu().numpy())
V = torch.linalg.lstsq(Q[local, :].t(), Q.t()).solution
            cores[j] = torch.reshape(V, [Rs[j], Is[j], Rs[j+1]])
# Map local indices to global ones
local_i, local_r = np.unravel_index(local, [Is[j], Rs[j + 1]])
rsets[j - 1] = np.c_[local_i, rsets[j][local_r, :]]
for k, t in enumerate(tensors):
if t.cores[j].dim() == 3: # TT core
t_rinterfaces[k][j-1] = torch.einsum('iaj,ja->ia', [t.cores[j][:, local_i, :], t_rinterfaces[k][j][:, local_r]])
else: # CP factor
t_rinterfaces[k][j-1] = torch.einsum('ai,ia->ia', [t.cores[j][local_i, :], t_rinterfaces[k][j][:, local_r]])
# Leave the first core ready
V = evaluate_function(0)
cores[0] = V
# Evaluate validation error
val_eps = torch.norm(ys_val - tn.Tensor(cores)[Xs_val].torch()) / norm_ys_val
info['val_epss'].append(val_eps)
if val_eps < eps:
converged = True
if verbose: # Print status
if _minimize:
print('| best: {:.8g}'.format(info['min']), end='')
else:
print('| eps: {:.3e}'.format(val_eps), end='')
print(' | time: {:8.4f} | largest rank: {:3d}'.format(time.time() - start, max(Rs)), end='')
if converged:
print(' <- converged: eps < {}'.format(eps))
elif i == max_iter-1:
print(' <- max_iter was reached: {}'.format(max_iter))
else:
print()
if converged:
break
elif i < max_iter - 1 and kickrank is not None: # Augment ranks
newRs = Rs.copy()
newRs[1:-1] = np.minimum(rmax, newRs[1:-1] + kickrank)
for n in list(range(1, N)) + list(range(N - 1, 0, -1)):
newRs[n] = min(newRs[n - 1] * Is[n - 1], newRs[n], Is[n] * newRs[n + 1])
extra = np.hstack([np.random.randint(0, Is[n + 1], [max(newRs), 1]) for n in range(N - 1)] + [np.zeros([max(newRs), 1], dtype=int)])
for n in range(N - 1):
if newRs[n + 1] > Rs[n + 1]:
rsets[n] = np.vstack([rsets[n], extra[:newRs[n + 1]-Rs[n + 1], n:]])
Rs = newRs
t_linterfaces, t_rinterfaces = init_interfaces(tensors, rsets, N, device) # Recompute interfaces
if val_eps > eps and not _minimize and not suppress_warnings:
logging.warning('eps={:g} (larger than {}) when cross-approximating {}'.format(val_eps, eps, function))
if verbose:
print('Did {} function evaluations, which took {:.4g}s ({:.4g} evals/s)'.format(
info['nsamples'], info['eval_time'], info['nsamples'] / info['eval_time']))
print()
ret = tn.Tensor([c if isinstance(c, torch.Tensor) else torch.tensor(c) for c in cores])
if return_info:
info['lsets'] = lsets
info['rsets'] = rsets
info['Rs'] = Rs
info['left_locals'] = left_locals
info['total_time'] = time.time() - start
info['val_eps'] = val_eps
return ret, info
else:
return ret
def cross_forward(
info,
function=lambda x: x,
domain=None,
tensors=None,
function_arg='vectors',
return_info=False):
"""
Given TT-cross indices and a black-box function (to be evaluated on an arbitrary grid), computes a differentiable TT tensor as given by the TT-cross interpolation formula.
Reference: I. Oseledets, E. Tyrtyshnikov: `"TT-cross Approximation for Multidimensional Arrays" (2009) <http://www.mat.uniroma2.it/~tvmsscho/papers/Tyrtyshnikov5.pdf>`_
:param info: dictionary with the indices returned by `tntorch.cross()`
    :param function: a function :math:`\\mathbb{R}^M \\to \\mathbb{R}`, as in `tntorch.cross()`
    :param domain: domain on which `function` will be evaluated, as in `tntorch.cross()`
    :param tensors: list of :math:`M` TT tensors on which `function` will be evaluated
    :param function_arg: type of argument accepted by `function`. See `tntorch.cross()`
:param return_info: Boolean, if True, will also return a dictionary with informative metrics about the algorithm's outcome
    :return: a TT :class:`Tensor` (if `return_info`=True, also a dictionary)
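
    :Example:

    A sketch of the intended workflow; `f` and `t0` are illustrative placeholders.

    >>> t, info = tn.cross(function=f, tensors=[t0], return_info=True)
    >>> t2 = tn.cross_forward(info, function=f, tensors=[t0])  # reuses the cross indices; gradients can flow through `f`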
"""
assert domain is not None or tensors is not None
assert function_arg in ('vectors', 'matrix')
device = None
if function_arg == 'matrix':
def f(*args):
return function(torch.cat([arg[:, None] for arg in args], dim=1))
else:
f = function
if tensors is None:
tensors = tn.meshgrid(domain)
device = domain[0].device
if not hasattr(tensors, '__len__'):
tensors = [tensors]
Is = list(tensors[0].shape)
N = len(Is)
# Load index information from dictionary
lsets = info['lsets']
rsets = info['rsets']
left_locals = info['left_locals']
Rs = info['Rs']
if return_info:
info['Xs'] = torch.zeros(0, N)
info['shapes'] = []
t_linterfaces, t_rinterfaces = init_interfaces(tensors, rsets, N, device)
def evaluate_function(j): # Evaluate function over Rs[j] x Rs[j+1] fibers, each of size I[j]
Xs = []
for k, t in enumerate(tensors):
V = torch.einsum('ai,ibj,jc->abc', [t_linterfaces[k][j], tensors[k].cores[j], t_rinterfaces[k][j]])
Xs.append(V.flatten())
evaluation = f(*Xs)
if return_info:
info['Xs'] = torch.cat((info['Xs'], torch.cat([x[:, None] for x in Xs], dim=1).detach().cpu()), dim=0)
info['shapes'].append([Rs[j], Is[j], Rs[j + 1]])
V = torch.reshape(evaluation, [Rs[j], Is[j], Rs[j + 1]])
return V
cores = []
# Cross-interpolation formula, left-to-right
for j in range(0, N-1):
# Update tensors for current indices
V = evaluate_function(j)
V = torch.reshape(V, [-1, V.shape[2]]) # Left unfolding
A = V[left_locals[j], :]
X = torch.linalg.lstsq(A.t(), V.t()).solution.t()
cores.append(X.reshape(Rs[j], Is[j], Rs[j + 1]))
# Map local indices to global ones
local_r, local_i = np.unravel_index(left_locals[j], [Rs[j], Is[j]])
lsets[j + 1] = np.c_[lsets[j][local_r, :], local_i]
for k, t in enumerate(tensors):
t_linterfaces[k][j + 1] = torch.einsum('ai,iaj->aj',
[t_linterfaces[k][j][local_r, :], t.cores[j][:, local_i, :]])
# Leave the last core ready
X = evaluate_function(N-1)
cores.append(X)
if return_info:
return tn.Tensor(cores), info
else:
return tn.Tensor(cores)