from ..tenalg_utils import _validate_contraction_modes
from ...utils import prod
import tensorly as tl
def tensordot(tensor1, tensor2, modes, batched_modes=()):
    """Batched tensor contraction between two tensors on specified modes

    The contraction is computed as a single (batched) matrix product: tensor1
    is permuted/reshaped to ``(*batch, free1, contracted)`` and tensor2 to
    ``(*batch, contracted, free2)``, they are multiplied with ``tl.matmul``,
    and the result is reshaped and permuted back.

    Parameters
    ----------
    tensor1 : tl.tensor
    tensor2 : tl.tensor
    modes : int list or int
        modes on which to contract tensor1 and tensor2; these modes are summed
        over and do not appear in the result. Parsed by
        ``_validate_contraction_modes``.
    batched_modes : int or tuple[int], default is ()
        modes along which the contraction is batched: the k-th batched mode of
        tensor1 is paired with the k-th batched mode of tensor2, and these
        modes are kept in the result instead of being summed over. Parsed by
        ``_validate_contraction_modes(..., batched_modes=True)``. The batched
        modes may be listed in any order.

    Returns
    -------
    contraction : tensor1 contracted with tensor2 on the specified modes
        Its modes are the non-contracted modes of tensor1 (batched modes
        included) in their original order, followed by the free
        (non-contracted, non-batched) modes of tensor2 in their original order.

    Notes
    -----
    Any error raised by ``_validate_contraction_modes`` propagates unchanged.
    """
    shape1 = tl.shape(tensor1)
    shape2 = tl.shape(tensor2)
    ndim1 = tl.ndim(tensor1)
    ndim2 = tl.ndim(tensor2)

    modes1, modes2 = _validate_contraction_modes(shape1, shape2, modes)
    batch_modes1, batch_modes2 = _validate_contraction_modes(
        shape1, shape2, batched_modes, batched_modes=True
    )

    # Inner dimension of the matrix product: all contracted modes flattened.
    contraction_dim = prod([shape1[i] for i in modes1])
    # Batch sizes must follow the order of ``batch_modes1``: that is the order
    # in which the transposes below move the batch modes to the front. Listing
    # them by increasing mode index instead silently scrambles the data when
    # the batched modes are not given in sorted order.
    batch_shape = [shape1[i] for i in batch_modes1]

    # Tensor1 is reorganised to (batch_modes, free_modes1, contraction_modes)
    free_modes1 = [
        i for i in range(ndim1) if i not in batch_modes1 and i not in modes1
    ]
    free_shape1 = [shape1[i] for i in free_modes1]
    tensor1 = tl.transpose(tensor1, batch_modes1 + free_modes1 + modes1)
    # Explicit sizes rather than -1: e.g. numpy cannot infer a -1 when another
    # dimension is 0 (zero-size contraction).
    tensor1 = tl.reshape(
        tensor1, (*batch_shape, prod(free_shape1), contraction_dim)
    )

    # Tensor2 is reorganised to (batch_modes, contraction_modes, free_modes2)
    free_modes2 = [
        i for i in range(ndim2) if i not in batch_modes2 and i not in modes2
    ]
    free_shape2 = [shape2[i] for i in free_modes2]
    tensor2 = tl.transpose(tensor2, batch_modes2 + modes2 + free_modes2)
    tensor2 = tl.reshape(
        tensor2, (*batch_shape, contraction_dim, prod(free_shape2))
    )

    res = tl.matmul(tensor1, tensor2)
    # Modes of res: (batch_modes1 in given order, free modes of tensor1,
    # free modes of tensor2).
    res = tl.reshape(res, (*batch_shape, *free_shape1, *free_shape2))

    # Permutation that moves the batch modes back to their original place among
    # tensor1's remaining modes (contracted modes are gone), keeping tensor2's
    # free modes at the end.
    batch_position = {mode: position for position, mode in enumerate(batch_modes1)}
    n_batches = len(batch_modes1)
    final_modes = []
    free_counter = 0
    for i in range(ndim1):
        if i in modes1:
            continue
        if i in batch_position:
            final_modes.append(batch_position[i])
        else:
            final_modes.append(n_batches + free_counter)
            free_counter += 1
    final_modes += [i for i in range(tl.ndim(res)) if i not in final_modes]

    if final_modes:
        res = tl.transpose(res, final_modes)
    return res