Revision aaae58cc70f03ac357af64aa1300ab00eaf9bb6d authored by JeanKossaifi on 23 October 2016, 18:50:41 UTC, committed by JeanKossaifi on 23 October 2016, 18:50:41 UTC
1 parent 1e94ea4
_khatri_rao.py
import numpy as np
# Author: Jean Kossaifi
def khatri_rao(matrices, skip_matrix=None, reverse=False):
    """Khatri-Rao product of a list of matrices

    This can be seen as a column-wise kronecker product
    (see [1]_ for more details).

    Parameters
    ----------
    matrices : ndarray list
        list of matrices with the same number of columns, i.e.::

            for i in len(matrices):
                matrices[i] is a matrix of shape `(n_i, m)`

    skip_matrix : None or int, optional, default is None
        if not None, index of a matrix to skip
    reverse : bool, optional
        if True, the order of the matrices is reversed

    Returns
    -------
    khatri_rao_product: matrix of shape `(prod(n_i), m)`
        where `prod(n_i) = prod([m.shape[0] for m in matrices])`
        i.e. the product of the number of rows of all the matrices in the
        product.

    Raises
    ------
    ValueError
        If there is no matrix left to multiply (empty list, or the only
        matrix was skipped), if one of the matrices does not have exactly
        2 dimensions, or if the matrices do not all have the same number
        of columns.

    Notes
    -----
    Mathematically:

    .. math::
         \\text{If every matrix } U_k \\text{ is of size } (I_k \\times R),\\\\
         \\text{Then } \\left(U_1 \\bigodot \\cdots \\bigodot U_n \\right) \\text{ is of size } (\\prod_{k=1}^n I_k \\times R)

    A more intuitive but slower implementation is::

        kr_product = np.zeros((n_rows, n_columns))
        for i in range(n_columns):
            cum_prod = matrices[0][:, i]  # Accumulates the khatri-rao product of the i-th columns
            for matrix in matrices[1:]:
                cum_prod = np.einsum('i,j->ij', cum_prod, matrix[:, i]).ravel()
            # the i-th column corresponds to the kronecker product of all the i-th columns of all matrices:
            kr_product[:, i] = cum_prod

        return kr_product

    References
    ----------
    .. [1] T.G.Kolda and B.W.Bader, "Tensor Decompositions and Applications",
       SIAM REVIEW, vol. 51, n. 3, pp. 455-500, 2009.
    """
    # Always work on a fresh list so the caller's sequence is never modified.
    if skip_matrix is not None:
        matrices = [matrix for i, matrix in enumerate(matrices)
                    if i != skip_matrix]
    else:
        matrices = list(matrices)

    if not matrices:
        raise ValueError('khatri_rao requires at least one matrix '
                         '(after skipping skip_matrix, if any).')

    # Validate the dimensionality of each matrix *before* reading its number
    # of columns, so that a malformed first matrix triggers the ValueError
    # below rather than an IndexError on `shape[1]`.
    n_columns = None
    for i, matrix in enumerate(matrices):
        if matrix.ndim != 2:
            raise ValueError('All the matrices must have exactly 2 dimensions! '
                             'Matrix {} has dimension {} != 2.'.format(
                                 i, matrix.ndim))
        if n_columns is None:
            n_columns = matrix.shape[1]
        elif matrix.shape[1] != n_columns:
            raise ValueError('All matrices must have same number of columns! '
                             'Matrix {} has {} columns != {}.'.format(
                                 i, matrix.shape[1], n_columns))

    if reverse:
        # Note: we do NOT use .reverse() which would reverse matrices even
        # outside this function
        matrices = matrices[::-1]

    # Build the einsum call in its integer-sublist form: factor k carries the
    # axis labels [k, shared], and the output is [0, 1, ..., n_factors - 1, shared].
    # This is the equivalent of 'az,bz,...->ab...z' but, unlike single-letter
    # labels, a row label can never collide with the shared column label.
    n_factors = len(matrices)
    shared_axis = n_factors
    operands = []
    for axis, matrix in enumerate(matrices):
        operands.append(matrix)
        operands.append([axis, shared_axis])
    operands.append(list(range(n_factors + 1)))

    # Collapse all the row axes into one, keeping the shared column axis last.
    return np.einsum(*operands).reshape((-1, n_columns))
Computing file changes ...