https://github.com/rballester/tntorch
Revision 241bf7ad2b806f6677a5e23534247f35f3a70f10 (parent be80cb2), authored and committed by rballester on 19 February 2023, 19:35:27 UTC: "Exact method for moments"
autodiff.py
import tntorch as tn
import torch
import numpy as np
import time
from functools import reduce


def optimize(tensors, loss_function, optimizer=torch.optim.Adam, tol=1e-4, max_iter=1e4, print_freq=500, verbose=True):
    """
    High-level wrapper for iterative learning.

    Default stopping criterion: the loss (or its relative improvement) must fall below `tol`.
    In addition, the per-iteration loss improvement must be slowing down.

    :param tensors: one or several tensors; will be fed to `loss_function` and optimized in place
    :param loss_function: must take `tensors` and return a scalar (or a tuple/list of scalars, which will be summed)
    :param optimizer: one from https://pytorch.org/docs/stable/optim.html. Default is torch.optim.Adam
    :param tol: stopping tolerance (see the criterion above); set to None to disable early stopping
    :param max_iter: maximum number of iterations. Default is 1e4
    :param print_freq: progress is printed once every this many iterations. Default is 500
    :param verbose: if True, print progress. Default is True
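
    Example (a minimal sketch; `tn.rand` and `tn.dist` are used as in the tntorch
    tutorials, and the shape and rank below are purely illustrative)::

        X = torch.randn(32, 32, 32)  # full target tensor
        t = tn.rand([32, 32, 32], ranks_tt=5, requires_grad=True)  # learnable TT tensor

        def loss(t):
            return tn.dist(t, X)  # Frobenius distance between t and X

        tn.optimize(t, loss)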
    """

    if not isinstance(tensors, (list, tuple)):
        tensors = [tensors]
    parameters = []
    for t in tensors:
        if isinstance(t, tn.Tensor):
            if t.batch:
                raise ValueError('Batched tensors are not supported.')

            parameters.extend([c for c in t.cores if c.requires_grad])
            parameters.extend([U for U in t.Us if U is not None and U.requires_grad])
        elif t.requires_grad:
            parameters.append(t)
    if len(parameters) == 0:
        raise ValueError("There are no parameters to optimize. Did you forget a requires_grad=True somewhere?")

    optimizer = optimizer(parameters)
    losses = []
    converged = False
    start = time.time()
    iter = 0
    while True:
        optimizer.zero_grad()
        loss = loss_function(*tensors)
        if not isinstance(loss, (tuple, list)):
            loss = [loss]
        total_loss = sum(loss)
        total_loss.backward(retain_graph=True)
        optimizer.step()

        losses.append(total_loss.detach())

        # Loss improvement over the last iteration (negative means the loss decreased)
        if len(losses) >= 2:
            delta_loss = losses[-1] - losses[-2]
        else:
            delta_loss = float('-inf')
        # Stop when the loss (or its relative improvement) falls below `tol`
        # and, additionally, the per-iteration improvement is decelerating
        if iter >= 2 and tol is not None and \
                (losses[-1] <= tol or -delta_loss / losses[-1] <= tol) and \
                losses[-2] - losses[-1] < losses[-3] - losses[-2]:
            converged = True
            break
        if iter == max_iter:
            break
        if verbose and iter % print_freq == 0:
            print('iter: {: <{}} | loss: '.format(iter, len('{}'.format(max_iter))), end='')
            print(' + '.join(['{:10.6f}'.format(l.item()) for l in loss]), end='')
            if len(loss) > 1:
                print(' = {:10.4}'.format(losses[-1].item()), end='')
            print(' | total time: {:9.4f}'.format(time.time() - start))

        iter += 1
    if verbose:
        print('iter: {: <{}} | loss: '.format(iter, len('{}'.format(max_iter))), end='')
        print(' + '.join(['{:10.6f}'.format(l.item()) for l in loss]), end='')
        if len(loss) > 1:
            print(' = {:10.4}'.format(losses[-1].item()), end='')
        print(' | total time: {:9.4f}'.format(time.time() - start), end='')
        if converged:
            print(' <- converged (tol={})'.format(tol))
        else:
            print(' <- max_iter was reached: {}'.format(max_iter))


def dof(t):
    """
    Compute the number of degrees of freedom of a tensor network.

    It is the sum of the sizes (number of entries) of all tensor nodes (cores and factor matrices) that have the requires_grad=True flag.

    :param t: input tensor

    :return: an integer
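
    Example (illustrative; assumes a TT tensor with trainable cores and no Tucker factors)::

        t = tn.rand([32, 32, 32], ranks_tt=5, requires_grad=True)
        tn.dof(t)  # 1*32*5 + 5*32*5 + 5*32*1 = 1120 trainable entries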
    """

    result = 0
    for n in range(t.dim()):
        if t.cores[n].requires_grad:
            result += t.cores[n].numel()
        if t.Us[n] is not None and t.Us[n].requires_grad:
            result += t.Us[n].numel()
    return result