Revision ac62d292d7de9b49855b7ca53f2645a10c63246a authored by rballester on 16 September 2022, 15:17:15 UTC, committed by rballester on 16 September 2022, 15:17:15 UTC
1 parent 363d461
util.py
import numpy as np
import tntorch as tn
def _sorted_random_subset(size):
    """
    Draw a random subset of ``range(size)`` and return it in increasing order.

    The subset's cardinality is itself uniform in ``[0, size]``, so the result
    may be empty or may contain every index.

    :param size: number of candidate indices
    :return: sorted list of distinct indices
    """
    # The cardinality is drawn before the subset itself (same RNG order as the
    # original inline expression).
    how_many = np.random.randint(size + 1)
    return sorted(np.random.choice(size, how_many, replace=False))


def _random_tucker_ranks(ndim):
    """
    Draw the Tucker ranks for a random ``ndim``-dimensional tensor.

    :param ndim: number of modes
    :return: ``None`` (no Tucker factors at all) with probability 1/4; otherwise
        a list of length ``ndim`` where a random subset of modes holds a rank
        in [1, 4] and the remaining modes hold ``None``
    """
    if np.random.randint(4) == 0:
        return None
    ranks = [None] * ndim
    for mode in _sorted_random_subset(ndim):
        ranks[mode] = np.random.randint(1, 5)
    return ranks


def _random_core_ranks(ndim):
    """
    Draw the TT and CP ranks for a random ``ndim``-dimensional tensor.

    :param ndim: number of cores
    :return: a pair ``(ranks_tt, ranks_cp)``, one of:

        - probability 1/4: pure CP, ``(None, r)`` with a single int ``r``
          in [1, 4];
        - probability 3/16: pure TT, ``(array, None)`` with ``ndim - 1``
          TT ranks in [1, 4] (a NumPy array);
        - probability 9/16: TT/CP hybrid, two lists. A random subset of cores
          is CP; the TT ranks touching those cores are ``None``.
    """
    if np.random.randint(4) == 0:
        return None, np.random.randint(1, 5)
    if np.random.randint(4) == 0:
        return np.random.randint(1, 5, ndim - 1), None

    # Hybrid: start from a full set of TT ranks (drawn before the CP subset),
    # then turn a random subset of cores into CP cores.
    tt = list(np.random.randint(1, 5, ndim - 1))
    cp = [None] * ndim
    for core in _sorted_random_subset(ndim):
        previous = cp[core - 1] if core > 0 else None
        # Adjacent CP cores share one rank; a fresh rank is only drawn at the
        # start of a run of CP cores. Presumably tn.randn requires this — confirm.
        cp[core] = previous if previous is not None else np.random.randint(1, 5)
        # The TT edges on either side of a CP core are left unset (None).
        # Presumably tn.randn requires this for CP cores — confirm.
        for edge in (core - 1, core):
            if 0 <= edge < ndim - 1:
                tt[edge] = None
    return tt, cp


def random_format(shape):
    """
    Build a random tensor in a randomly chosen (often hybrid) format.

    Tucker ranks are drawn first, then TT/CP ranks, all from NumPy's global
    RNG. Seeding ``np.random`` therefore makes the result reproducible.

    :param shape: sequence of mode sizes; its length is the number of dimensions
    :return: a tensor, as produced by ``tn.randn`` with the drawn ranks
    """
    ndim = len(shape)
    ranks_tucker = _random_tucker_ranks(ndim)
    ranks_tt, ranks_cp = _random_core_ranks(ndim)
    return tn.randn(
        shape,
        ranks_tt=ranks_tt,
        ranks_cp=ranks_cp,
        ranks_tucker=ranks_tucker,
    )
Computing file changes ...