{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import tensorly as tl\n",
"tl.set_backend('pytorch')\n",
"\n",
"from tensorly.decomposition import parafac\n",
"from tensorly.cp_tensor import cp_to_tensor, CPTensor"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Generate tensor\n",
"# Seed torch's global RNG so the random factors below -- and therefore every\n",
"# number reported further down -- are identical after Restart & Run All.\n",
"torch.manual_seed(0)\n",
"\n",
"rank = 70\n",
"t_shape = (10, 20, 30)\n",
"\n",
"# One random (mode_size, rank) factor matrix per tensor mode, with unit weights.\n",
"factors = [torch.randn(mode_size, rank) for mode_size in t_shape]\n",
"weights = tl.ones(rank)\n",
"\n",
"# Assemble the full tensor from its CP weights and factor matrices.\n",
"t = cp_to_tensor(CPTensor((weights, factors)))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"approx error: 114.89116668701172, non-zero elements in approx: 4200\n"
]
}
],
"source": [
"# Parafac\n",
"r = rank\n",
"\n",
"# random_state pins any random numbers parafac draws while initialising,\n",
"# so the reported error is repeatable for a given input tensor.\n",
"cp = parafac(t, r, cvg_criterion='rec_error', random_state=0)\n",
"approx = cp_to_tensor(cp)\n",
"\n",
"approx_error = tl.norm(t - approx)\n",
"# Size of the factor matrices: r columns for every index of every mode.\n",
"nonzero_elems_count = sum(t_shape) * r\n",
"\n",
"# No backslash needed: the open parenthesis already continues the line.\n",
"print('approx error: {}, non-zero elements in approx: {}'.format(\n",
"    approx_error, nonzero_elems_count))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"approx error: 112.28209686279297, non-zero elements in approx: 4200.0\n"
]
}
],
"source": [
"# Sparse parafac\n",
"r = rank-5\n",
"sparsity = .05\n",
"\n",
"# With `sparsity` set, parafac returns a (cp, sparse) pair; the approximation\n",
"# is the low-rank reconstruction plus the sparse term.\n",
"# random_state pins any random numbers drawn while initialising.\n",
"cp, sparse = parafac(t, r, sparsity=sparsity, cvg_criterion='rec_error',\n",
"                     random_state=0)\n",
"approx_sp = cp_to_tensor(cp) + sparse\n",
"\n",
"approx_error = tl.norm(t - approx_sp)\n",
"# Factor-matrix entries plus a `sparsity` fraction of all tensor entries.\n",
"# The fraction is cast to int so the count stays an integer, matching the\n",
"# integer count printed for the plain parafac run.\n",
"nonzero_elems_count = sum(t_shape) * r + int(sparsity * np.prod(t_shape))\n",
"\n",
"print('approx error: {}, non-zero elements in approx: {}'.format(\n",
"    approx_error, nonzero_elems_count))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}