# Raw File
# test__khatri_rao.py
from numpy.testing import assert_raises, assert_array_equal, assert_array_almost_equal
import numpy as np

from .._khatri_rao import khatri_rao


# Author: Jean Kossaifi


def test_khatri_rao():
    """Exercise ``khatri_rao`` on shape, error, and value checks.

    The scenarios covered, in order:

    * the result of combining matrices with ``rows`` rows each and a shared
      number of columns has shape ``(prod(rows), columns)``;
    * a ``ValueError`` is expected when the matrices do not all share the
      same number of columns, or when one input is not 2-dimensional;
    * hand-written expected values for two 3x3 matrices, with and without
      ``reverse=True``;
    * hand-written expected values for two 3x4 matrices and for a list of
      four matrices;
    * ``skip_matrix=i`` is expected to give the same result as removing the
      ``i``-th matrix from the input list.
    """
    columns = 4
    rows = [3, 4, 2]
    matrices = [np.arange(k * columns).reshape((k, columns)) for k in rows]
    res = khatri_rao(matrices)
    # resulting matrix must be of shape (prod(n_rows), n_columns); derive the
    # expectation from the fixture so the two cannot drift apart
    n_rows = int(np.prod(rows))
    assert tuple(res.shape) == (n_rows, columns)

    # fail case: all matrices must have same number of columns
    shapes = [[3, 4], [3, 4], [3, 2]]
    matrices = [np.arange(i * j).reshape((i, j)) for (i, j) in shapes]
    with assert_raises(ValueError):
        khatri_rao(matrices)

    # all matrices should be of dim 2...
    matrices = [np.eye(3), np.arange(3 * 2 * 2).reshape((3, 2, 2))]
    with assert_raises(ValueError):
        khatri_rao(matrices)

    # Classic example/test
    t1 = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
    t2 = np.array([[1, 4, 7],
                   [2, 5, 8],
                   [3, 6, 9]])
    # Rows are ordered with the row index of t2 varying fastest.
    true_res = np.array([[1., 8., 21.],
                         [2., 10., 24.],
                         [3., 12., 27.],
                         [4., 20., 42.],
                         [8., 25., 48.],
                         [12., 30., 54.],
                         [7., 32., 63.],
                         [14., 40., 72.],
                         [21., 48., 81.]])
    # Same products, but with the row index of t1 varying fastest.
    reversed_true_res = np.array([[1., 8., 21.],
                                  [4., 20., 42.],
                                  [7., 32., 63.],
                                  [2., 10., 24.],
                                  [8., 25., 48.],
                                  [14., 40., 72.],
                                  [3., 12., 27.],
                                  [12., 30., 54.],
                                  [21., 48., 81.]])
    res = khatri_rao([t1, t2])
    assert_array_equal(res, true_res)
    reversed_res = khatri_rao([t1, t2], reverse=True)
    assert_array_equal(reversed_res, reversed_true_res)

    # A = np.hstack((np.eye(3), np.arange(3)[:, None]))
    A = np.array([[1., 0., 0., 0.],
                  [0., 1., 0., 1.],
                  [0., 0., 1., 2.]])
    B = np.array([[1., 0., 0., 3.],
                  [0., 1., 0., 4.],
                  [0., 0., 1., 5.]])
    true_res = np.array([[1., 0., 0., 0.],
                         [0., 0., 0., 0.],
                         [0., 0., 0., 0.],
                         [0., 0., 0., 3.],
                         [0., 1., 0., 4.],
                         [0., 0., 0., 5.],
                         [0., 0., 0., 6.],
                         [0., 0., 0., 8.],
                         [0., 0., 1., 10.]])
    assert_array_equal(khatri_rao([A, B]), true_res)

    # Four matrices with 3, 4, 2 and 2 rows and 3 columns each, filled with
    # consecutive integers from 1 to 33.
    U1 = np.reshape(np.arange(1, 10), (3, 3))
    U2 = np.reshape(np.arange(10, 22), (4, 3))
    U3 = np.reshape(np.arange(22, 28), (2, 3))
    U4 = np.reshape(np.arange(28, 34), (2, 3))
    U = [U1, U2, U3, U4]
    true_res = np.array([[6160, 14674, 25920],
                         [6820, 16192, 28512],
                         [7000, 16588, 29160],
                         [7750, 18304, 32076],
                         [8008, 18676, 32400],
                         [8866, 20608, 35640],
                         [9100, 21112, 36450],
                         [10075, 23296, 40095],
                         [9856, 22678, 38880],
                         [10912, 25024, 42768],
                         [11200, 25636, 43740],
                         [12400, 28288, 48114],
                         [11704, 26680, 45360],
                         [12958, 29440, 49896],
                         [13300, 30160, 51030],
                         [14725, 33280, 56133],
                         [24640, 36685, 51840],
                         [27280, 40480, 57024],
                         [28000, 41470, 58320],
                         [31000, 45760, 64152],
                         [32032, 46690, 64800],
                         [35464, 51520, 71280],
                         [36400, 52780, 72900],
                         [40300, 58240, 80190],
                         [39424, 56695, 77760],
                         [43648, 62560, 85536],
                         [44800, 64090, 87480],
                         [49600, 70720, 96228],
                         [46816, 66700, 90720],
                         [51832, 73600, 99792],
                         [53200, 75400, 102060],
                         [58900, 83200, 112266],
                         [43120, 58696, 77760],
                         [47740, 64768, 85536],
                         [49000, 66352, 87480],
                         [54250, 73216, 96228],
                         [56056, 74704, 97200],
                         [62062, 82432, 106920],
                         [63700, 84448, 109350],
                         [70525, 93184, 120285],
                         [68992, 90712, 116640],
                         [76384, 100096, 128304],
                         [78400, 102544, 131220],
                         [86800, 113152, 144342],
                         [81928, 106720, 136080],
                         [90706, 117760, 149688],
                         [93100, 120640, 153090],
                         [103075, 133120, 168399]])
    res = khatri_rao(U)
    assert_array_equal(res, true_res)

    # Skipping a matrix must match dropping it from the input list.
    res_1 = khatri_rao(U, skip_matrix=1)
    res_2 = khatri_rao([U[0]] + U[2:])
    assert_array_equal(res_1, res_2)

    res_1 = khatri_rao(U, skip_matrix=0)
    res_2 = khatri_rao(U[1:])
    assert_array_equal(res_1, res_2)
# back to top