import tensorly as tl
# TODO: add a `batched_modes` argument, as in `batched_tensor_dot`?
def batched_outer(tensors):
    """Returns a generalized batched outer product of a sequence of tensors.

    The first mode of every tensor is a batch (sample) mode. The outer
    product is taken independently for each sample, over all remaining modes.

    Parameters
    ----------
    tensors : iterable of tensors
        tensors[0] of shape (n_samples, J1, ..., JN),
        tensors[1] of shape (n_samples, K1, ..., KM), ...
        All tensors must have the same first dimension ``n_samples``.

    Returns
    -------
    tensor
        Batched outer product of all the tensors, of shape
        (n_samples, J1, ..., JN, K1, ..., KM, ...).
        If a single tensor is given, it is returned unchanged.

    Raises
    ------
    ValueError
        If ``tensors`` is empty, or if the tensors do not all share the same
        batch size.
    """
    # Use an iterator so that generators are supported, and so that an empty
    # input is detected explicitly instead of failing with UnboundLocalError.
    iterator = iter(tensors)
    try:
        res = next(iterator)
    except StopIteration:
        raise ValueError(
            "At least one tensor is required to compute an outer product."
        ) from None

    shape_res = tl.shape(res)
    for i, tensor in enumerate(iterator, start=1):
        shape = tl.shape(tensor)
        n_samples = shape[0]
        if n_samples != shape_res[0]:
            raise ValueError(
                f"Tensor {i} has a batch-size of {n_samples} but those before had a batch-size of {shape_res[0]}, "
                "all tensors should have the same batch-size."
            )
        n_modes = len(shape) - 1  # non-batch modes of the new tensor
        n_modes_res = len(shape_res) - 1  # non-batch modes accumulated so far
        # Append singleton modes to `res` and insert singleton modes into
        # `tensor` right after its batch mode. Broadcasting then multiplies
        # every entry of `res` with every entry of `tensor`, sample by sample.
        res = tl.reshape(res, tuple(shape_res) + (1,) * n_modes) * tl.reshape(
            tensor, (n_samples,) + (1,) * n_modes_res + tuple(shape[1:])
        )
        shape_res = tl.shape(res)
    return res
def outer(tensors):
    """Returns a generalized outer product of a sequence of tensors.

    Parameters
    ----------
    tensors : iterable of tensors
        tensors[0] of shape (J1, ..., JN),
        tensors[1] of shape (K1, ..., KM), ...

    Returns
    -------
    tensor
        Outer product of all the tensors, of shape
        (J1, ..., JN, K1, ..., KM, ...).
        If a single tensor is given, it is returned unchanged.

    Raises
    ------
    ValueError
        If ``tensors`` is empty.
    """
    # Use an iterator so that generators are supported, and so that an empty
    # input is detected explicitly instead of failing with UnboundLocalError.
    iterator = iter(tensors)
    try:
        res = next(iterator)
    except StopIteration:
        raise ValueError(
            "At least one tensor is required to compute an outer product."
        ) from None

    shape_res = tl.shape(res)
    for tensor in iterator:
        shape = tl.shape(tensor)
        # Append singleton modes to `res` and prepend singleton modes to
        # `tensor`. Broadcasting then multiplies every entry of `res` with
        # every entry of `tensor`.
        res = tl.reshape(res, tuple(shape_res) + (1,) * len(shape)) * tl.reshape(
            tensor, (1,) * len(shape_res) + tuple(shape)
        )
        shape_res = tl.shape(res)
    return res