# Source: https://github.com/tensorly/tensorly
# File: _khatri_rao.py
from ... import backend as T
import warnings
# Author: Jean Kossaifi
# License: BSD 3 clause
def khatri_rao(matrices, weights=None, skip_matrix=None, reverse=False, mask=None):
    """Khatri-Rao product of a list of matrices

    This can be seen as a column-wise kronecker product.
    (see [1]_ for more details).

    If one matrix only is given (possibly after skipping one), that matrix is
    directly returned, scaled column-wise by ``weights`` when they are given.

    Parameters
    ----------
    matrices : 2D-array list
        list of matrices with the same number of columns, i.e.::

            for i in range(len(matrices)):
                matrices[i].shape = (n_i, m)

        If the first entry is a 1D array, every 1D entry is treated as a
        matrix with a single column (a warning is emitted in that case).
    weights : 1D-array
        array of weights for each rank, of length m, the number of column of the factors
        (i.e. m == factor[i].shape[1] for any factor)
    skip_matrix : None or int, optional, default is None
        if not None, index of a matrix to skip
    reverse : bool, optional
        if True, the order of the matrices is reversed
    mask : array, optional
        forwarded unchanged to the backend's ``kr`` function.
        NOTE(review): it is not applied on the single-matrix shortcut;
        confirm whether it should be.

    Returns
    -------
    khatri_rao_product: matrix of shape ``(prod(n_i), m)``
        where ``prod(n_i) = prod([m.shape[0] for m in matrices])``
        i.e. the product of the number of rows of all the matrices in the product.

    Raises
    ------
    ValueError
        if no matrix remains (empty input, or the only matrix was skipped),
        if a matrix is not 2D after vector promotion, or if the matrices do
        not all have the same number of columns.

    Notes
    -----
    Mathematically:

    .. math::
        \\text{If every matrix } U_k \\text{ is of size } (I_k \\times R),\\\\
        \\text{Then } \\left(U_1 \\bigodot \\cdots \\bigodot U_n \\right) \\text{ is of size } (\\prod_{k=1}^n I_k \\times R)

    A more intuitive but slower implementation is::

        kr_product = np.zeros((n_rows, n_columns))
        for i in range(n_columns):
            cum_prod = matrices[0][:, i]  # Accumulates the khatri-rao product of the i-th columns
            for matrix in matrices[1:]:
                cum_prod = np.einsum('i,j->ij', cum_prod, matrix[:, i]).ravel()
            # the i-th column corresponds to the kronecker product of all the i-th columns of all matrices:
            kr_product[:, i] = cum_prod

        return kr_product

    References
    ----------
    .. [1] T.G.Kolda and B.W.Bader, "Tensor Decompositions and Applications",
       SIAM REVIEW, vol. 51, n. 3, pp. 455-500, 2009.
    """
    if skip_matrix is not None:
        matrices = [matrices[i] for i in range(len(matrices)) if i != skip_matrix]

    # Fail with a clear message instead of an IndexError on matrices[0] below.
    if not matrices:
        raise ValueError("khatri_rao requires at least one matrix.")

    # Khatri-rao of only one matrix: return that matrix. The weights are still
    # applied, since they are documented as a per-column scaling of the result.
    if len(matrices) == 1:
        if weights is not None:
            # (n, m) * (m,) broadcasts over rows; a lone vector (n,) * (1,) keeps its shape.
            return matrices[0] * weights
        return matrices[0]

    if T.ndim(matrices[0]) == 1:
        # Promote only genuine vectors to single-column matrices. Entries of any
        # other rank are left as-is so the checks below reject them, instead of
        # silently flattening them into one column.
        matrices = [T.reshape(m, (-1, 1)) if T.ndim(m) == 1 else m for m in matrices]
        warnings.warn(
            "Khatri-rao of a series of vectors instead of matrices. "
            "Considering each as a matrix with 1 column.",
            stacklevel=2,
        )

    # Every entry must be 2D and share the first matrix's number of columns.
    n_columns = None
    for i, matrix in enumerate(matrices):
        if T.ndim(matrix) != 2:
            raise ValueError(
                "All the matrices must have exactly 2 dimensions! "
                f"Matrix {i} has dimension {T.ndim(matrix)} != 2."
            )
        if n_columns is None:
            n_columns = matrix.shape[1]
        elif matrix.shape[1] != n_columns:
            raise ValueError(
                "All matrices must have same number of columns! "
                f"Matrix {i} has {matrix.shape[1]} columns != {n_columns}."
            )

    if reverse:
        # Slicing builds a new list; .reverse() would mutate the caller's list.
        matrices = matrices[::-1]

    return T.kr(matrices, weights=weights, mask=mask)