import torch
import numpy as np

import src.debfly.tree
import src.partition
import src.debfly.product
from src.cluster_tree import ClusterTree


def _as_long_tensor(permutation):
    """Return ``permutation`` as a ``torch.LongTensor``, converting only if needed."""
    if isinstance(permutation, torch.LongTensor):
        return permutation
    return torch.LongTensor(permutation)


def col_permute(matrix, permutation):
    """
    Permute the columns of a 2D matrix.

    :param matrix: object with a 2-element ``shape`` that supports ``matrix[:, idx]``
    :param permutation: 1D sequence of column indices (tuple, list, LongTensor, ...)
    :return: ``matrix[:, permutation]``, i.e. output column ``j`` is input column
        ``permutation[j]``
    """
    assert len(matrix.shape) == 2
    permutation = _as_long_tensor(permutation)
    assert len(permutation.shape) == 1
    return matrix[:, permutation]


def row_permute(matrix, permutation):
    """
    Permute the rows of a 2D matrix.

    :param matrix: object with a 2-element ``shape`` that supports ``matrix[idx, :]``
    :param permutation: 1D sequence of row indices (tuple, list, LongTensor, ...)
    :return: ``matrix[permutation, :]``, i.e. output row ``i`` is input row
        ``permutation[i]``
    """
    assert len(matrix.shape) == 2
    permutation = _as_long_tensor(permutation)
    assert len(permutation.shape) == 1
    return matrix[permutation, :]


def permutation_to_matrix(permutation):
    """
    Build the permutation matrix ``P`` such that ``M @ P == col_permute(M, permutation)``.

    :param permutation: sequence of indices (the original docstring says a tuple)
    :return: ``torch.eye(n)`` with its columns permuted, where ``n = len(permutation)``
    """
    n = len(permutation)
    return col_permute(torch.eye(n), permutation)


def inverse_permutation(permutation):
    """
    Compute the inverse permutation.

    :param permutation: anything ``np.array`` accepts as a 1D index sequence
    :return: tuple holding the ``np.argsort`` of ``permutation``
    """
    permutation = np.array(permutation)
    # argsort of a permutation of 0..n-1 is its inverse.
    inverse_perm = np.argsort(permutation)
    return tuple(inverse_perm)


def composition(a, b):
    """
    Permute according to a, then permute according to b.

    :param a: first permutation (index sequence or LongTensor)
    :param b: second permutation (index sequence or LongTensor)
    :return: LongTensor ``a[b]``, so that
        ``col_permute(col_permute(M, a), b) == col_permute(M, composition(a, b))``
    """
    a = _as_long_tensor(a)
    b = _as_long_tensor(b)
    return a[b]


def perm_col_partition_to_monarch_canonical(col_partition, tree):
    """
    Return the permutation to go from col_partition to monarch canonical.

    :param col_partition: nested index sequence; validated against ``tree`` by
        ``src.partition.check_col_partition_is_monarch_compatible``
    :param tree: passed through to the compatibility check
    :return: 1D LongTensor obtained by flattening ``col_partition``
    """
    src.partition.check_col_partition_is_monarch_compatible(col_partition, tree)
    return torch.LongTensor(col_partition).view(-1)


def _perm_row_canonical_to_block(diag, row):
    """Flatten ``src.partition.canonical_row_partition_monarch(diag, row)`` into a 1D LongTensor."""
    p = torch.LongTensor(src.partition.canonical_row_partition_monarch(diag, row))
    return p.view(-1)


def perm_row_partition_to_monarch_canonical(row_partition, tree):
    """
    Return the permutation to go from row_partition to monarch canonical.

    :param row_partition: nested index sequence; validated against ``tree`` by
        ``src.partition.check_row_partition_is_monarch_compatible``
    :param tree: passed through to the compatibility check
    :return: LongTensor composing the flattened ``row_partition`` with the inverse
        of the flattened canonical row partition
    """
    src.partition.check_row_partition_is_monarch_compatible(row_partition, tree)
    # The canonical partition is parameterized by the outer and inner lengths
    # of the given partition.
    diag = len(row_partition)
    row = len(row_partition[0])
    p1 = torch.LongTensor(row_partition).view(-1)
    canonical = _perm_row_canonical_to_block(diag, row)
    return composition(p1, inverse_permutation(canonical))


def _leaf_order(cluster_tree, depth):
    """
    Return the nodes of ``cluster_tree`` at level ``depth`` as a flat list.

    ``np.squeeze(..., axis=1)`` requires each node at that level to have length 1
    (presumably a single index per leaf -- confirm against ClusterTree).
    """
    return np.squeeze(np.array(cluster_tree.nodes_at_level(depth)), axis=1).tolist()


def _perm_tree_to_canonical(cluster_tree, build_canonical_tree):
    """
    Shared implementation of the row/column "tree to butterfly canonical" permutations.

    :param cluster_tree: the ``ClusterTree`` to convert from
    :param build_canonical_tree: callable taking ``list(range(size))`` and returning
        the canonical cluster tree to convert to
    :return: LongTensor composing the leaf order of ``cluster_tree`` with the inverse
        of the canonical leaf order
    """
    assert isinstance(cluster_tree, ClusterTree)
    size = len(cluster_tree.node)
    log_n = int(np.log2(size))
    # Only power-of-two sizes are supported.
    assert size == 2 ** log_n
    leaves = _leaf_order(cluster_tree, log_n)
    canonical_tree = build_canonical_tree(list(range(size)))
    canonical_leaves = _leaf_order(canonical_tree, log_n)
    return composition(leaves, inverse_permutation(canonical_leaves))


def perm_row_tree_to_butterfly_canonical(row_cluster_tree):
    """
    Return the permutation to go from the row cluster tree to the canonical
    butterfly row cluster tree.

    The canonical row tree is built with
    ``src.cluster_tree.even_odd_split_cluster_tree``.

    :param row_cluster_tree: a ``ClusterTree`` whose ``node`` has power-of-two length
    :return: LongTensor permutation
    """
    return _perm_tree_to_canonical(
        row_cluster_tree, src.cluster_tree.even_odd_split_cluster_tree
    )


def perm_col_tree_to_butterfly_canonical(col_cluster_tree):
    """
    Return the permutation to go from the column cluster tree to the canonical
    butterfly column cluster tree.

    The canonical column tree is built with
    ``src.cluster_tree.middle_split_cluster_tree``.

    :param col_cluster_tree: a ``ClusterTree`` whose ``node`` has power-of-two length
    :return: LongTensor permutation
    """
    return _perm_tree_to_canonical(
        col_cluster_tree, src.cluster_tree.middle_split_cluster_tree
    )


def perm_DFT(num_factors):
    """
    Build the list of permutation factors used by ``bit_reversal_permutation``.

    For each ``i`` in ``1 .. num_factors - 1`` and each type ``j`` in ``0, 1, 2``
    (in that order) the factor is ``kron(I, perm_type(i + 1, j))`` where ``I`` is
    the identity of size ``2 ** (num_factors - 1 - i)``, so every factor is a
    ``2 ** num_factors`` square matrix.

    :param num_factors: number of factors (log2 of the matrix size)
    :return: list of ``3 * (num_factors - 1)`` NumPy arrays (empty if
        ``num_factors <= 1``)
    """
    result = []
    # Level 0 contributes no factor, so start at 1.
    for i in range(1, num_factors):
        identity = np.identity(2 ** (num_factors - 1 - i))
        for j in range(3):
            result.append(np.kron(identity, perm_type(i + 1, j)))
    return result


def perm_type(i, type):
    """
    Build one of three ``2**i x 2**i`` 0/1 matrices.

    Type 0 is c in paper.
    Type 1 is b in paper.
    Type 2 is a in paper.

    :param i: log2 of the matrix size
    :param type: 0, 1, or anything else (treated as type 2)
    :return: NumPy array of shape ``(2**i, 2**i)``
    """
    size = 2 ** i
    half = size // 2
    idx = np.arange(half)
    result = np.zeros((size, size))
    if type == 0:
        # Top half: identity. Bottom half: columns taken in reverse from the end.
        result[idx, idx] = 1
        result[half + idx, size - 1 - idx] = 1
    elif type == 1:
        # Top-left block: anti-diagonal. Bottom-right block: identity.
        result[half - 1 - idx, idx] = 1
        result[half + idx, half + idx] = 1
    else:
        # Top half selects even columns, bottom half selects odd columns.
        result[idx, idx * 2] = 1
        result[half + idx, idx * 2 + 1] = 1
    return result


def bit_reversal_permutation(log_n):
    """
    Return ``src.debfly.product.matrix_product`` applied to ``perm_DFT(log_n)``.

    NOTE(review): per the function name this is intended to be the bit-reversal
    permutation matrix of size ``2 ** log_n`` -- confirm against matrix_product's
    multiplication order.

    :param log_n: log2 of the matrix size
    :return: whatever ``matrix_product`` returns for the list of factors
    """
    perm_mat = src.debfly.product.matrix_product(perm_DFT(log_n))
    return perm_mat


if __name__ == '__main__':
    # Sanity checks: each print shows an elementwise comparison (or a permutation).
    p = (3, 2, 0, 1)
    mat = torch.randn(4, 4)
    print(mat @ permutation_to_matrix(p) == col_permute(mat, p))
    print(permutation_to_matrix(p).t() @ mat == row_permute(mat, p))

    print(inverse_permutation(p))
    print(col_permute(col_permute(mat, inverse_permutation(p)), p) == mat)
    print(col_permute(col_permute(mat, p), inverse_permutation(p)) == mat)
    print(row_permute(row_permute(mat, inverse_permutation(p)), p) == mat)
    print(row_permute(row_permute(mat, p), inverse_permutation(p)) == mat)

    p1 = (0, 1, 3, 2)
    p2 = (0, 2, 1, 3)
    p1_then_p2 = composition(p1, p2)
    print(p1_then_p2)
    print(col_permute(col_permute(mat, p1), p2) == col_permute(mat, p1_then_p2))