https://github.com/palash1992/DynamicGEM
Tip revision: 911fe36dc8fa5f85deb0d931210cac823955f643 authored by Palash Goyal on 07 April 2020, 19:10:13 UTC
Modified metrics defn
Modified metrics defn
Tip revision: 911fe36
evaluation_util.py
import numpy as np
from random import randint
def getRandomEdgePairs(node_num, sample_ratio=0.01, is_undirected=True):
num_pairs = int(sample_ratio * node_num * (node_num - 1))
if is_undirected:
num_pairs = num_pairs / 2
current_sets = set()
while (len(current_sets) < num_pairs):
p = (randint(node_num), randint(node_num))
if (p in current_sets):
continue
if (is_undirected and (p[1], p[0]) in current_sets):
continue
current_sets.add(p)
return list(current_sets)
def getEdgeListFromAdjMtx(adj, threshold=0.0, is_undirected=True, edge_pairs=None):
result = []
node_num = adj.shape[0]
if edge_pairs:
for (st, ed) in edge_pairs:
if adj[st, ed] >= threshold:
result.append((st, ed, adj[st, ed]))
else:
for i in range(node_num):
for j in range(node_num):
if (j == i):
continue
if (is_undirected and i >= j):
continue
if adj[i, j] > threshold:
result.append((i, j, adj[i, j]))
return result
def splitDiGraphToTrainTest(di_graph, train_ratio, is_undirected=True):
train_digraph = di_graph.copy()
test_digraph = di_graph.copy()
node_num = di_graph.number_of_nodes()
for (st, ed, w) in di_graph.edges_iter(data='weight', default=1):
if (is_undirected and st >= ed):
continue
if (np.random.uniform() <= train_ratio):
test_digraph.remove_edge(st, ed)
if (is_undirected):
test_digraph.remove_edge(ed, st)
else:
train_digraph.remove_edge(st, ed)
if (is_undirected):
train_digraph.remove_edge(ed, st)
return (train_digraph, test_digraph)