https://github.com/tensorly/tensorly
Tip revision: 1295ccb09626f89f20d0c0183d618f96b4833bf1 authored by Jean Kossaifi on 08 May 2018, 21:04:53 UTC
FIX README + bump version
FIX README + bump version
Tip revision: 1295ccb
test_tucker_regression.py
import numpy as np
from ..tucker_regression import TuckerRegressor
from ...base import tensor_to_vec, partial_tensor_to_vec
from ...metrics.regression import RMSE
from ... import backend as T
def test_TuckerRegressor():
    """Test for TuckerRegressor.

    Builds a synthetic linear-regression problem whose weight tensor has a
    simple block structure:

    * Inputs are standard-normal tensors of shape
      ``(image_height, image_width, n_channels)``.
    * Targets are the inner product of each flattened sample with the
      flattened weight tensor.

    The test fits a ``TuckerRegressor`` on the first ``n_train`` samples and
    checks that the held-out RMSE is at most ``rmse_tol``. It also checks that
    ``get_params`` / ``set_params`` round-trip the ``weight_ranks`` parameter.
    """
    # Parameters of the experiment
    image_height = 10
    image_width = 10
    n_channels = 3
    ranks = [5, 5, 2]
    rmse_tol = 0.05  # maximum accepted RMSE on the held-out samples
    n_samples = 1200
    n_train = 1000

    # Use a local, seeded generator so the data is reproducible without
    # mutating the global NumPy RNG state. The estimator may still draw its
    # own initialisation from elsewhere.
    rng = np.random.RandomState(1234)

    # Generate random samples
    X = T.tensor(rng.normal(size=(n_samples, image_height, image_width, n_channels),
                            loc=0, scale=1))

    # Ground-truth weights: a centred square in each channel, with a distinct
    # value per channel (1, 2, -1), zero elsewhere.
    regression_weights = np.zeros((image_height, image_width, n_channels))
    regression_weights[2:-2, 2:-2, 0] = 1
    regression_weights[2:-2, 2:-2, 1] = 2
    regression_weights[2:-2, 2:-2, 2] = -1
    regression_weights = T.tensor(regression_weights)

    # Flatten every sample (keeping the sample axis) and project it onto the
    # flattened weights to obtain one scalar target per sample.
    y = T.dot(partial_tensor_to_vec(X, skip_begin=1), tensor_to_vec(regression_weights))

    X_train = X[:n_train]
    X_test = X[n_train:]
    y_train = y[:n_train]
    y_test = y[n_train:]

    # Note: 10e-8 == 1e-7 (solver convergence tolerance, unrelated to rmse_tol).
    estimator = TuckerRegressor(weight_ranks=ranks, tol=10e-8, reg_W=1, n_iter_max=200, verbose=True)
    estimator.fit(X_train, y_train)
    y_pred = estimator.predict(X_test)
    error = RMSE(y_test, y_pred)
    T.assert_(error <= rmse_tol, msg='Tucker Regression : RMSE={} > {}'.format(error, rmse_tol))

    # get_params must report the ranks the estimator was constructed with.
    params = estimator.get_params()
    T.assert_(params['weight_ranks'] == ranks, msg='get_params did not return the correct parameters')

    # set_params must propagate a modified parameter back onto the estimator.
    new_ranks = [5, 5, 5]
    params['weight_ranks'] = new_ranks
    estimator.set_params(**params)
    T.assert_(estimator.weight_ranks == new_ranks, msg='set_params did not correctly set the given parameters')