from ..tenalg_utils import _validate_contraction_modes
from ...utils import prod
import tensorly as tl


def tensordot(tensor1, tensor2, modes, batched_modes=()):
    """Batched tensor contraction between two tensors on specified modes

    Both tensors are rearranged so that the whole contraction is carried out
    by a single (batched) matrix multiplication, and the result is then
    reshaped and transposed back.

    Parameters
    ----------
    tensor1 : tl.tensor
    tensor2 : tl.tensor
    modes : int list or int
        modes on which to contract tensor1 and tensor2
        (interpreted by ``_validate_contraction_modes``)
    batched_modes : int or tuple[int]
        modes that are kept as batch dimensions instead of being contracted.
        ``tensor2`` is reshaped using the batch sizes read from ``tensor1``,
        so paired batch modes are expected to have the same size.

    Returns
    -------
    contraction : tensor1 contracted with tensor2 on the specified modes
        The leading modes are the non-contracted modes of ``tensor1`` in their
        original order (batch modes stay at the position they had in
        ``tensor1``), followed by the modes of ``tensor2`` that are neither
        contracted nor batched.
    """
    shape1 = tl.shape(tensor1)
    shape2 = tl.shape(tensor2)

    modes1, modes2 = _validate_contraction_modes(shape1, shape2, modes)
    batch_modes1, batch_modes2 = _validate_contraction_modes(
        shape1, shape2, batched_modes, batched_modes=True
    )

    contraction_shape = [s for (i, s) in enumerate(shape1) if i in modes1]
    contraction_dim = prod(contraction_shape)

    # The batch sizes must follow the order of ``batch_modes1`` because that
    # is the order in which the transposes below lay out the batch axes.
    # (Listing them in ascending axis order instead would make the reshape
    # silently scramble data whenever the batch modes are not sorted.)
    batch_shape = [shape1[i] for i in batch_modes1]

    # Prepare to reorganize the modes afterwards by moving the batch modes
    # back to their place (while omitting the modes contracted over).
    # ``final_modes[k]`` is the axis of the matmul result that must end up at
    # position ``k`` of the output.
    final_modes = []
    n_batches = len(batch_modes1)
    free_counter = 0
    for i in range(tl.ndim(tensor1)):
        if i in modes1:
            continue
        if i in batch_modes1:
            # Batch axes of the result are ordered like ``batch_modes1``
            final_modes.append(batch_modes1.index(i))
        else:
            # Free axes of tensor1 come right after the batch axes
            final_modes.append(n_batches + free_counter)
            free_counter += 1

    # We will reorganize tensor1 to (batch_modes, new_modes1, contraction_modes)
    new_modes1 = [i for i in range(tl.ndim(tensor1)) if i not in batch_modes1 + modes1]
    new_shape1 = [shape1[i] for i in new_modes1]
    tensor1 = tl.transpose(tensor1, batch_modes1 + new_modes1 + modes1)
    tensor1 = tl.reshape(tensor1, (*batch_shape, -1, contraction_dim))

    # Tensor2 will be (batch_modes, contraction_modes, new_modes2)
    new_modes2 = [i for i in range(tl.ndim(tensor2)) if i not in batch_modes2 + modes2]
    new_shape2 = [shape2[i] for i in new_modes2]
    tensor2 = tl.transpose(tensor2, batch_modes2 + modes2 + new_modes2)
    tensor2 = tl.reshape(tensor2, (*batch_shape, contraction_dim, -1))

    # (*batch, free1, K) @ (*batch, K, free2) -> (*batch, free1, free2)
    res = tl.matmul(tensor1, tensor2)
    res = tl.reshape(res, (*batch_shape, *new_shape1, *new_shape2))

    # The free modes of tensor2 are already trailing: keep them where they are
    final_modes += [i for i in range(tl.ndim(res)) if i not in final_modes]

    # ``final_modes`` is empty only when the result has no dimension left
    if final_modes:
        res = tl.transpose(res, final_modes)

    return res