swh:1:snp:4e3e7077647a709f15b8c1b32ce7100175d0580b
Tip revision: 0fa8ab5b81e54410ee9b34a1ac5e45eebda7e387 authored by Jean Kossaifi on 06 August 2019, 17:16:06 UTC
Upload only once
Upload only once
Tip revision: 0fa8ab5
proximal.py
import tensorly as tl
# Author: Jean Kossaifi
# License: BSD 3 clause
def soft_thresholding(tensor, threshold):
    """Apply the soft-thresholding (shrinkage) operator element-wise.

    Computes ``sign(tensor) * max(abs(tensor) - threshold, 0)``.

    Parameters
    ----------
    tensor : ndarray
    threshold : float or ndarray with shape tensor.shape
        * If a float, the same threshold is applied to the whole tensor.
        * If an ndarray, one threshold is applied per element; entries
          whose threshold is 0 are returned unchanged.

    Returns
    -------
    ndarray
        thresholded tensor on which the operator has been applied

    Examples
    --------
    Basic shrinkage

    >>> import tensorly as tl
    >>> from tensorly.tenalg.proximal import soft_thresholding
    >>> tensor = tl.tensor([[1, -2, 1.5], [-4, 3, -0.5]])
    >>> soft_thresholding(tensor, 1.1)
    array([[ 0. , -0.9,  0.4],
           [-2.9,  1.9,  0. ]])

    Example with missing values

    >>> mask = tl.tensor([[0, 0, 1], [1, 0, 1]])
    >>> soft_thresholding(tensor, mask*1.1)
    array([[ 1. , -2. ,  0.4],
           [-2.9,  3. ,  0. ]])

    See also
    --------
    svd_thresholding : SVD-thresholding operator
    """
    # Shrink the magnitudes by the threshold, flooring at zero so that
    # entries smaller than the threshold vanish instead of flipping sign.
    reduced_magnitude = tl.abs(tensor) - threshold
    shrunk = tl.clip(reduced_magnitude, a_min=0)

    # Restore the original signs.
    return tl.sign(tensor) * shrunk
def svd_thresholding(matrix, threshold):
    """Soft-threshold the singular values of a matrix.

    The matrix is decomposed with ``tl.partial_svd``, its singular values
    are passed through `soft_thresholding`, and the matrix is rebuilt from
    the shrunk values and the original singular vectors.

    Parameters
    ----------
    matrix : ndarray
    threshold : float

    Returns
    -------
    ndarray
        matrix on which the operator has been applied

    See also
    --------
    procrustes : procrustes operator
    """
    # Request as many components as the smaller dimension of the matrix.
    n_components = min(matrix.shape)
    left, singular_values, right = tl.partial_svd(matrix, n_eigenvecs=n_components)

    shrunk_values = soft_thresholding(singular_values, threshold)

    # Reshaping to a column makes the product scale each row of `right`
    # by its corresponding shrunk singular value.
    scaled_right = tl.reshape(shrunk_values, (-1, 1)) * right
    return tl.dot(left, scaled_right)
def procrustes(matrix):
    """Apply the Procrustes operator to a matrix.

    The matrix is decomposed with ``tl.partial_svd`` and the product of the
    two singular-vector factors is returned; the singular values themselves
    are discarded.

    Parameters
    ----------
    matrix : ndarray

    Returns
    -------
    ndarray
        matrix on which the Procrustes operator has been applied
        has the same shape as the original tensor

    See also
    --------
    svd_thresholding : SVD-thresholding operator
    """
    # Request as many components as the smaller dimension of the matrix.
    n_components = min(matrix.shape)
    left, _, right = tl.partial_svd(matrix, n_eigenvecs=n_components)
    return tl.dot(left, right)