# Source: https://bitbucket.org/NetaRS/sched_analytics
# Raw File
# Tip revision: ed1f2acca39de9eb5f34a6cb5b0c8db1492f74f2 authored by NetaRS on 12 December 2020, 09:53:39 UTC
# bounded traffic distributed
# Tip revision: ed1f2ac
# distributed_win.py
import sys
from os import path
import json
import random
import argparse
from matrices import MatrixList, add_with_matrix, make_zero_matrix, subtract_with_matrix, multiply_with_scalar
from mwm import DictList
from collections import defaultdict
from copy import deepcopy
# import pdb; pdb.set_trace()
# from pudb import set_trace; set_trace()
from numpy import inf
# This script runs our distributed algorithm, which in turn is based on iSLIP
from compute_tp import SeriesStats


DIR = "outputs/run0/"

DEBUG_ITERATION = False

# computing timeslot [x,x+1]
# filename contains traffic for [x,x+1] --> loaded to traffic
# current_matching is the matching used in x
# prev_traffic corresponds to timeslots [x-1,x]
# mwm_matching is the last computed mwm, its per millisecond weights are in weight_mwm


def update_distributed_pairing(last_pairing, window_weight, centralized_pairing, centralized_weight, threshold, iterations, n_tors=80):
    """Improve a one-to-one ToR pairing in place using windowed traffic weights.

    Runs ``iterations`` request rounds.  In every round each ToR whose current
    link is not a sufficiently heavy centralized link looks for the heaviest
    other ToR that beats its current link and asks to connect to it.  The
    requests of that round are then served in random order, with at most one
    change per ToR per round.

    Args:
        last_pairing: mapping ToR -> current peer (``-1`` means unpaired).  It is
            indexed for every ToR id and modified in place.
        window_weight: ``window_weight[x][y]`` is the windowed traffic from x to y;
            a link's weight is the sum of both directions.
        centralized_pairing: mapping ToR -> peer in the centralized matching.
        centralized_weight: mapping ToR -> weight of its centralized link.
        threshold: a link that equals the centralized link is left untouched while
            its window weight is at least ``threshold * centralized_weight``.
        iterations: number of request rounds to run.
        n_tors: number of ToRs; ids are ``0 .. n_tors - 1``.

    Returns:
        None.  ``last_pairing`` holds the result.
    """
    def window_pair_weight(x, y):
        # Bidirectional windowed traffic between x and y.
        return window_weight[x][y] + window_weight[y][x]

    def connect(x, y):
        # Pair x with y, releasing whatever peers they had before.
        if last_pairing[x] == y:
            return
        if last_pairing[x] != -1:
            last_pairing[last_pairing[x]] = -1
        if last_pairing[y] != -1:
            last_pairing[last_pairing[y]] = -1
        last_pairing[x] = y
        last_pairing[y] = x

    def process_requests(requests):
        # Serve the round's requests in random order; a ToR that already took
        # part in a change during this round ignores further requests.
        random.shuffle(requests)
        updated = set()
        for (x, y) in requests:
            if x in updated or y in updated:
                continue
            connect(x, y)
            updated.add(x)
            updated.add(y)

    for iteration in range(iterations):
        # Requests are valid only for the pairing they were computed from, so
        # every round starts with an empty list.  Re-serving requests from an
        # earlier round would go through connect() without re-checking weights
        # or protection and could replace a heavier or protected link.
        requests = []
        for tor in range(n_tors):
            cur_peer = last_pairing[tor]  # current (distributed) peer of this ToR
            if cur_peer == -1:
                cur_traffic = 0
            elif cur_peer == centralized_pairing[tor]:
                cur_traffic = window_pair_weight(tor, cur_peer)
                if cur_traffic >= threshold * centralized_weight[tor]:
                    # Centralized link is still heavy enough: keep it.
                    continue
            else:
                cur_traffic = window_pair_weight(tor, cur_peer)

            # Candidates: other ToRs whose link with `tor` beats the current one.
            ivector = {}
            for j in range(n_tors):
                if j == tor:
                    # A ToR cannot be paired with itself, whatever the diagonal holds.
                    continue
                bi_traffic = window_pair_weight(tor, j)
                if bi_traffic > cur_traffic:
                    ivector[j] = bi_traffic

            # Try candidates from the heaviest to the lightest; request the first feasible one.
            for k in sorted(ivector.keys(), key=ivector.get, reverse=True):
                dk = last_pairing[k]  # current (distributed) peer of the candidate
                if dk == -1:
                    # The candidate is free.
                    request_target = k
                elif centralized_pairing[k] == dk and window_pair_weight(k, dk) >= threshold * centralized_weight[k]:
                    # The candidate holds a centralized link above the threshold: do not break it.
                    continue
                elif window_pair_weight(k, dk) < window_pair_weight(k, tor):
                    # The candidate's current link is lighter than the one offered.
                    request_target = k
                else:
                    continue
                requests.append((tor, request_target))
                break
        process_requests(requests)


def update_distributed_pairing_ex(last_pairing, window_weight, centralized_pairing, centralized_weight, threshold,
                                  iterations, n_tors=80, max_degree=1, use_max_peer_cent_weights=False, **kwargs):
    """Multi-link variant of the distributed pairing update (up to ``max_degree`` links per ToR).

    ``last_pairing`` maps each ToR to the set of its current peers and is
    modified in place.  A link is "protected" when it is connected, also
    appears in ``centralized_pairing`` and its window weight reaches
    ``threshold`` times a centralized reference weight (the link's own
    centralized weight, or the heavier of the two endpoints' heaviest
    centralized links when ``use_max_peer_cent_weights`` is set).

    In every iteration each ToR requests its heaviest eligible peers, as many
    as it has slots not taken by protected links.  All requests collected so
    far are then shuffled and applied; a request is applied only if the new
    link is at least as heavy as the lightest unprotected link it would have
    to evict on either side.

    Extra keyword arguments are accepted and ignored.
    """
    # Heaviest centralized link weight of every ToR that has centralized links (0 if none).
    heaviest_central = {}
    for u in centralized_pairing:
        heaviest_central[u] = max([centralized_weight[u][v] for v in centralized_pairing[u]] + [0])

    def link_weight(a, b):
        # Bidirectional windowed traffic between a and b.
        return window_weight[a][b] + window_weight[b][a]

    def linked(a, b):
        return a in last_pairing and b in last_pairing[a]

    def protected(a, b):
        if not linked(a, b):
            return False
        if a not in centralized_pairing or b not in centralized_pairing[a]:
            return False
        w = link_weight(a, b)
        if use_max_peer_cent_weights:
            reference = max([heaviest_central[a], heaviest_central[b]])
        else:
            reference = centralized_weight[a][b]
        return w >= threshold * reference

    def weakest_unprotected(a):
        # (peer, weight) of the lightest unprotected link of `a`;
        # (None, 0) while `a` still has a free slot, (None, inf) if all links are protected.
        peers = last_pairing[a]
        if len(peers) < max_degree:
            return None, 0
        weakest, weakest_w = None, inf
        for p in peers:
            if protected(a, p):
                continue
            w = link_weight(a, p)
            if weakest is None or w < weakest_w:
                weakest, weakest_w = p, w
        return weakest, weakest_w

    def link_up(a, b):
        assert a != b
        if b in last_pairing[a]:
            assert a in last_pairing[b]
            return
        evict_a, evict_a_w = weakest_unprotected(a)
        evict_b, evict_b_w = weakest_unprotected(b)
        w = link_weight(a, b)
        if w < evict_a_w or w < evict_b_w:
            # The new link would replace something heavier: refuse.
            return
        for node, victim in ((a, evict_a), (b, evict_b)):
            if victim is not None:
                last_pairing[victim].remove(node)
                last_pairing[node].remove(victim)
        last_pairing[a].add(b)
        last_pairing[b].add(a)

    pending = []  # requests are kept across iterations; link_up() re-validates each one

    for _round in range(iterations):
        for tor in range(n_tors):
            floor_w = weakest_unprotected(tor)[1]
            slots = max_degree
            candidates = {}
            for other in range(n_tors):
                if other == tor:
                    continue
                if protected(tor, other):
                    slots -= 1
                    continue
                w = link_weight(tor, other)
                if w == 0 or w < floor_w:
                    continue
                if w <= weakest_unprotected(other)[1]:
                    continue
                candidates[other] = w
            # heaviest candidates first, limited to the slots not held by protected links
            ranked = sorted(candidates, key=candidates.get, reverse=True)
            pending.extend((tor, other) for other in ranked[:slots])
        random.shuffle(pending)
        for a, b in pending:
            link_up(a, b)


def update_distributed_pairing_ex_bounded(last_pairing, window_weight, centralized_pairing, centralized_weight,
                                          threshold, iterations, n_tors=80, max_degree=1,
                                          use_max_peer_cent_weights=False, max_reqs=None, **kwargs):
    """Multi-link distributed pairing update with a bounded number of requests per ToR.

    When ``max_reqs`` is None this simply delegates to
    ``update_distributed_pairing_ex``.  Otherwise every iteration runs three
    phases over ``last_pairing`` (mapping ToR -> set of peers, modified in place):

    1. Request: each ToR drops all of its unprotected links and requests up to
       ``max_reqs`` of its heaviest peers whose link weight is non-zero and not
       below the weight of the lightest unprotected link it held when the
       iteration reached it.
    2. Grant: each ToR grants the heaviest mutual requests, as many as it has
       slots not taken by protected links.
    3. Connect: a link is created when both ends granted each other.

    A link is "protected" when it is connected, appears in
    ``centralized_pairing`` and its window weight reaches ``threshold`` times
    the centralized reference weight (the link's own centralized weight, or the
    heavier of both endpoints' heaviest centralized links when
    ``use_max_peer_cent_weights`` is set).

    Extra keyword arguments are only forwarded on the ``max_reqs is None`` path.
    """
    if max_reqs is None:
        return update_distributed_pairing_ex(last_pairing, window_weight, centralized_pairing, centralized_weight,
                                             threshold, iterations, n_tors=n_tors, max_degree=max_degree,
                                             use_max_peer_cent_weights=use_max_peer_cent_weights, **kwargs)

    # Heaviest centralized link weight of every ToR that has centralized links (0 if none).
    max_peer_centralized_weight = {
        u: max([centralized_weight[u][v] for v in centralized_pairing[u]] + [0]) for u in centralized_pairing
    }

    def window_pair_weight(x, y):
        # Bidirectional windowed traffic between x and y.
        return window_weight[x][y] + window_weight[y][x]

    def is_connected(x, y):
        if x not in last_pairing:
            return False
        return y in last_pairing[x]

    def is_protected(x, y):
        if not is_connected(x, y):
            return False
        if x not in centralized_pairing:
            return False
        if y not in centralized_pairing[x]:
            return False
        weight = window_pair_weight(x, y)
        if use_max_peer_cent_weights:
            return weight >= threshold * max([max_peer_centralized_weight[x], max_peer_centralized_weight[y]])
        return weight >= threshold * centralized_weight[x][y]

    def minimal_unprotected_weight(x):
        # 0 while x has a free slot, inf when every link of x is protected,
        # otherwise the weight of its lightest unprotected link.
        if len(last_pairing[x]) < max_degree:
            return 0
        weights = [window_pair_weight(x, peer) for peer in last_pairing[x] if not is_protected(x, peer)]
        return min(weights) if weights else inf

    def disconnect(x, y):
        if y not in last_pairing[x]:
            return
        last_pairing[y].remove(x)
        last_pairing[x].remove(y)

    def simple_connect(x, y):
        last_pairing[x].add(y)
        last_pairing[y].add(x)
        # Grants are capped by the free degree, so these must hold.
        assert len(last_pairing[x]) <= max_degree
        assert len(last_pairing[y]) <= max_degree

    for iteration in range(iterations):
        # Phase 1: drop unprotected links and collect each ToR's requests.
        requested_peers = defaultdict(set)
        free_degrees = {}
        for tor in range(n_tors):
            # Taken before this ToR's unprotected links are dropped below.
            min_weight = minimal_unprotected_weight(tor)
            free_degree = max_degree
            peer_weights = {}
            for j in range(n_tors):
                if j == tor:
                    continue
                if is_protected(tor, j):
                    free_degree -= 1
                    continue
                if is_connected(tor, j):
                    # Unprotected links are torn down and re-negotiated every iteration.
                    disconnect(tor, j)
                peer_weight = window_pair_weight(tor, j)
                if peer_weight == 0 or peer_weight < min_weight:
                    continue
                peer_weights[j] = peer_weight
            free_degrees[tor] = free_degree
            potential_peers = sorted(peer_weights.keys(), key=lambda j: peer_weights[j], reverse=True)
            requested_peers[tor] = set(potential_peers[:max_reqs])

        # Phase 2: each ToR grants its heaviest mutual requests, up to its free degree.
        granted_peers = defaultdict(set)
        for tor in range(n_tors):
            ivector = {}
            for j in requested_peers[tor]:
                if tor in requested_peers[j]:
                    ivector[j] = window_pair_weight(tor, j)
            granted_peers[tor] = set(sorted(ivector.keys(), key=ivector.get, reverse=True)[:free_degrees[tor]])

        # Phase 3: connect the pairs that granted each other.
        for tor in range(n_tors):
            for j in granted_peers[tor]:
                if tor in granted_peers[j]:
                    simple_connect(tor, j)


def compare_matchings(a, b, n_tors=80):
    """Print the ToRs whose peer differs between matchings *a* and *b*.

    For every index ``i`` in ``0 .. n_tors - 1`` with ``a[i] != b[i]`` one line
    of the form ``"<i> <a[i]>-><b[i]>"`` is printed.  ``n_tors`` defaults to 80,
    the value that used to be hard-coded.
    """
    for i in range(n_tors):
        if a[i] != b[i]:
            # Single parenthesised expression: same output under the Python 2
            # print statement and the Python 3 print function.
            print(str(i) + " " + str(a[i]) + "->" + str(b[i]))


def verify_matching(matching):
    """Return False if two keys of *matching* share the same peer (``-1`` is exempt), else True."""
    peers = sorted(matching.values())
    # After sorting, duplicates are adjacent, so comparing neighbours is enough.
    for earlier, later in zip(peers, peers[1:]):
        if later != -1 and later == earlier:
            return False
    return True
    
    
class AllZeroRow(object):
    """Row stand-in whose every entry reads as 0, whatever the index."""

    def __getitem__(self, item):
        return 0


class AllZerosMatrix(object):
    """Matrix stand-in: ``m[x][y]`` is 0 for any ``x`` and ``y``."""

    # One row object is shared by all lookups; it carries no state.
    zero_row = AllZeroRow()

    def __getitem__(self, item):
        return self.zero_row


# Module-level all-zero matrix instance.
zero_matrix = AllZerosMatrix()


class TrafficWindow:
    """Sliding window over the last ``length`` traffic matrices, optionally delayed.

    ``sum`` accumulates every matrix currently in the window scaled by
    ``1.0 / length``.  With a non-zero ``delay`` a matrix only enters the
    window after ``delay`` further matrices have been added.
    """

    def __init__(self, length, delay=0):
        self.length, self.delay = length, delay
        self.sum = None      # scaled running sum; None until the first matrix enters the window
        self.pending = []    # matrices still waiting out the delay, oldest first
        self.matrices = []   # matrices currently inside the window, oldest first

    def add(self, matrix):
        """Feed the next matrix; it enters the window immediately or after the delay."""
        if self.delay:
            self.pending.append(matrix)
            if len(self.pending) <= self.delay:
                # Not enough matrices queued yet: nothing enters the window.
                return
            matrix = self.pending.pop(0)

        if len(self.matrices) == self.length:
            # Window is full: retire the oldest matrix.
            # NOTE(review): presumably removes expired * (1.0 / length) from self.sum in place - confirm in matrices.py.
            expired = self.matrices.pop(0)
            subtract_with_matrix(self.sum, expired, 1.0 / self.length)

        if self.sum is not None:
            add_with_matrix(self.sum, matrix, 1.0 / self.length)
        else:
            self.sum = multiply_with_scalar(matrix, 1.0 / self.length)

        self.matrices.append(matrix)

    def get_sum(self):
        """Return the scaled window sum, or the shared all-zero matrix while the window is empty."""
        return zero_matrix if self.sum is None else self.sum


def pairs_dif(new_pairs, old_pairs, dif_stats):
    """Record the link count and the number of changed links into *dif_stats*.

    ``dif_stats["totals"]`` receives ``len(new_pairs)`` and
    ``dif_stats["changes"]`` the size of the symmetric difference between the
    two pair sets; both are fed through their ``add`` method.
    """
    n_links = len(new_pairs)
    n_changed = len(new_pairs ^ old_pairs)  # links present in exactly one of the two sets
    dif_stats["totals"].add(n_links)
    dif_stats["changes"].add(n_changed)


def pairing_to_pairs(pairing):
    """Convert a ToR -> peer mapping into a set of ordered ``(low, high)`` pairs.

    Entries whose peer is ``-1`` (unpaired) are skipped.
    """
    pairs = set()
    for tor, peer in pairing.items():
        if peer == -1:
            continue
        pairs.add((peer, tor) if peer < tor else (tor, peer))
    return pairs


def compute_dist_only_throughput(window=1, iterations=3, win_delay=0,
                       n_milis=5000, n_tors=80, output_dir=".", **kwargs):
    per_mili_pattern = path.join(output_dir, "matrix_mili_%d")
    per_mili_matrix_list = MatrixList(per_mili_pattern)
    
    
    tps = []
    total_tp = 0
    
    trafic_window = TrafficWindow(window, win_delay)  # initialize the time period backward from which the distributed is computed
    distributed_pairing = defaultdict(lambda: -1)  # initialize the current pairs
    centralized_matches = dict(pairing=defaultdict(lambda: -1), weights=defaultdict(int))
    dif_stats = dict(totals=SeriesStats(), changes=SeriesStats())
    old_pairs = pairing_to_pairs(distributed_pairing)
    for t in range(n_milis):
        print "\r", t, "/", n_milis,
        matrix = list(per_mili_matrix_list[t])
        if t != 0:
            update_distributed_pairing(
                last_pairing=distributed_pairing,
                window_weight=trafic_window.get_sum(),
                centralized_pairing=centralized_matches["pairing"],
                centralized_weight=centralized_matches["weights"],  # doesnt matter
                threshold=1,  # doesnt matter
                iterations=iterations,
                n_tors=n_tors
            )
        new_pairs = pairing_to_pairs(distributed_pairing)
        pairs_dif(new_pairs, old_pairs, dif_stats)
        old_pairs = new_pairs
        tp = sum([matrix[x][distributed_pairing[x]] for x in distributed_pairing if
                  distributed_pairing[x] != -1])  # update the total throuhput of every pair
        tps.append(tp)
        total_tp += tp
        trafic_window.add(matrix)  # update the traffic window with the current traffic
    return tps, total_tp, dif_stats


def pairing_to_pairs_ex(pairing):
    """Convert a ToR -> collection-of-peers mapping into a set of ordered ``(low, high)`` pairs.

    Peers equal to ``-1`` are skipped.
    """
    pairs = set()
    for tor, peers in pairing.items():
        for peer in peers:
            if peer != -1:
                pairs.add((peer, tor) if peer < tor else (tor, peer))
    return pairs


def compute_dist_only_throughput_ex(window=1, iterations=3, win_delay=0,
                                    n_milis=5000, n_tors=80, output_dir=".", **kwargs):
    if kwargs.get("max_degree", 1) == 1:
        return compute_dist_only_throughput(window=window, iterations=iterations, win_delay=win_delay,
                                 n_milis=n_milis, n_tors=n_tors, output_dir=output_dir, **kwargs)
    per_mili_pattern = path.join(output_dir, "matrix_mili_%d")
    per_mili_matrix_list = MatrixList(per_mili_pattern)
    
    tps = []
    total_tp = 0
    
    trafic_window = TrafficWindow(window,
                                  win_delay)  # initialize the time period backward from which the distributed is computed
    distributed_pairing = defaultdict(set)  # initialize the current pairs
    centralized_matches = dict(links=defaultdict(set), weights=None)
    dif_stats = dict(totals=SeriesStats(), changes=SeriesStats())
    old_pairs = pairing_to_pairs_ex(distributed_pairing)
    
    for t in range(n_milis):
        print "\r", t, "/", n_milis,
        matrix = list(per_mili_matrix_list[t])
        if t != 0:
            update_distributed_pairing_ex_bounded(
                last_pairing=distributed_pairing,
                window_weight=trafic_window.get_sum(),
                centralized_pairing=centralized_matches["links"],
                centralized_weight=centralized_matches["weights"],  # doesnt matter
                iterations=iterations,
                n_tors=n_tors,
                **kwargs
            )
        new_pairs = pairing_to_pairs_ex(distributed_pairing)
        pairs_dif(new_pairs, old_pairs, dif_stats)
        old_pairs = new_pairs
        tp = sum([matrix[x][y] for x in distributed_pairing for y in distributed_pairing[x]])
        tps.append(tp)
        total_tp += tp
        trafic_window.add(matrix)  # update the traffic window with the current traffic
    return tps, total_tp, dif_stats


def compute_throughput(compute_epoch=1, agg_interval=1, agg_epoch_delay=0, top=None,
                       window=1, threshold=1, iterations=3, win_delay=0,
                       n_milis=5000, n_tors=80, output_dir=".", **kwargs):
    """Simulate the hybrid (centralized + distributed) schedule with one link per ToR.

    Time is split into epochs of ``compute_epoch`` milliseconds.  At the start
    of every epoch the pairing is reset to a centralized matching loaded from
    disk; inside the epoch ``update_distributed_pairing`` refines it every
    millisecond (except at t == 0) using the traffic window.

    File patterns (all under ``output_dir``; ``max_degree`` is read from
    ``kwargs`` and defaults to 1):
      * traffic:               ``matrix_mili_%d``
      * per-millisecond MWM:   ``mwm_mili_%d[_top_<top>]_deg_<max_degree>``
      * aggregated MWM:        ``mwm_agg_%d_%d-%d[_top_<top>]_deg_<max_degree>``

    Returns:
        ``(tps, total_tp, dif_stats)``: per-millisecond throughput list, its sum,
        and a dict with ``"totals"`` / ``"changes"`` ``SeriesStats`` fed by
        ``pairs_dif``.
    """
    print dict(compute_epoch=compute_epoch, agg_interval=agg_interval, agg_epoch_delay=agg_epoch_delay)
    per_mili_pattern = path.join(output_dir, "matrix_mili_%d")
    per_mili_matrix_list = MatrixList(per_mili_pattern)
    if top:
        per_mili_match_pattern = path.join(output_dir, "mwm_mili_%d_top_" + str(top)+("_deg_%d" % kwargs.get("max_degree", 1)))
    else:
        per_mili_match_pattern = path.join(output_dir, "mwm_mili_%d"+("_deg_%d" % kwargs.get("max_degree", 1)))
    per_mili_match_list = DictList(per_mili_match_pattern)

    if top:
        per_interval_match_pattern = path.join(output_dir, "mwm_agg_%d_%d-%d_top_" + str(top)+("_deg_%d" % kwargs.get("max_degree", 1)))
    else:
        per_interval_match_pattern = path.join(output_dir, "mwm_agg_%d_%d-%d"+("_deg_%d" % kwargs.get("max_degree", 1)))
    per_interval_match_list = DictList(per_interval_match_pattern)
    
    def get_centralized_matches(t):
        """Return ``{"pairing": ..., "weights": ...}`` for the epoch containing millisecond *t*.

        The matching is the MWM of the interval ``[start, end)``; ``weights[x]`` is
        the bidirectional traffic of x's centralized link over that interval
        divided by ``agg_interval``.
        """
        # t - t % compute_epoch                 = beginning of the current decision epoch
        # (agg_epoch_delay - 1) * compute_epoch = how far back (in epochs) the MWM interval ends
        # so `end` is the end of the interval the MWM is computed on.
        end = t - t % compute_epoch - (agg_epoch_delay - 1) * compute_epoch
        if end > n_milis:
            # The interval would run past the experiment: step one epoch back.
            end -= compute_epoch
        start = end - agg_interval                                              # where the MWM interval starts
        print "start", start, "end", end
        if start < 0:
            # No history yet: empty matching.  `traffic` stays unassigned here, which
            # is fine because the weight comprehension below iterates over an empty pairing.
            mwm = []
            #traffic = make_zero_matrix(n_tors, n_tors)
        elif agg_interval == 1:                                                 # the MWM is based on a single millisecond
            mwm = per_mili_match_list[start]                                    # MWM of that millisecond
            traffic = list(per_mili_matrix_list[start])                         # traffic matrix of that millisecond
        else:
            mwm = per_interval_match_list[(agg_interval, start, end - 1)]
            traffic = make_zero_matrix(n_tors, n_tors)
            for tt in range(start, end):                                        # accumulate the interval's traffic for the threshold test
                matrix = per_mili_matrix_list[tt]
                add_with_matrix(traffic, matrix)
        # Symmetric ToR -> peer mapping; unmatched ToRs read as -1.
        pairing = defaultdict(lambda : -1)
        pairing.update({m[0]: m[1] for m in mwm})
        pairing.update({m[1]: m[0] for m in mwm})

        mwm_weight = defaultdict(int)
        # Per-ToR weight of its centralized link, averaged over the interval.
        # NOTE(review): under Python 2 this `/` truncates if the traffic values are ints - confirm intended.
        mwm_weight.update({x: (traffic[x][pairing[x]] + traffic[pairing[x]][x]) / agg_interval for x in pairing})
        res = {}
        res["pairing"] = pairing
        res["weights"] = mwm_weight
        #res["last_traffic"] = traffic
        return res
    
    
    tps = []
    total_tp = 0
    
    trafic_window = TrafficWindow(window, win_delay)           # traffic history the distributed decision looks at
    distributed_pairing = defaultdict(lambda : -1)  # current pairing; -1 = unpaired
    dif_stats = dict(totals=SeriesStats(), changes=SeriesStats())
    old_pairs = pairing_to_pairs(distributed_pairing)
    for z in range(n_milis / compute_epoch):        # one pass per centralized computation (number of whole epochs in the experiment)
        start = z*compute_epoch                     # first millisecond of this epoch
        centralized_matches = get_centralized_matches(start)
        # Every epoch starts from a fresh copy of the centralized matching.
        distributed_pairing = centralized_matches["pairing"].copy()
        for t in range(z*compute_epoch, (z+1)*compute_epoch): # between two MWM updates the distributed computation takes place
            print "\r", t, "/", n_milis,
            matrix = list(per_mili_matrix_list[t])
            if t != 0:  # skipped only for the very first millisecond, when the window is still empty
                update_distributed_pairing(
                    distributed_pairing,
                    trafic_window.get_sum(),
                    centralized_matches["pairing"],
                    centralized_matches["weights"],
                    threshold,
                    iterations,
                    n_tors
                )
            new_pairs = pairing_to_pairs(distributed_pairing)
            pairs_dif(new_pairs, old_pairs, dif_stats)
            old_pairs = new_pairs
            tp = sum([matrix[x][distributed_pairing[x]] for x in distributed_pairing if distributed_pairing[x] != -1]) # traffic carried over every paired direction
            tps.append(tp)
            total_tp += tp
            trafic_window.add(matrix)           # the window learns about this millisecond only after it was served
    return tps, total_tp, dif_stats


def compute_throughput_ex(compute_epoch=1, agg_interval=1, agg_epoch_delay=0, top=None,
                       window=1, threshold=1, iterations=3, win_delay=0,
                       n_milis=5000, n_tors=80, output_dir=".", **kwargs):
    """Hybrid (centralized + distributed) simulation supporting several links per ToR.

    When ``kwargs`` has no ``max_degree`` or it equals 1 the work is delegated
    to ``compute_throughput``.  Otherwise each epoch of ``compute_epoch``
    milliseconds starts from the centralized links loaded from disk, and
    ``update_distributed_pairing_ex_bounded`` refines them every millisecond
    (except at t == 0).  ``kwargs`` is forwarded to that update function.

    File patterns are the same as in ``compute_throughput``.

    Returns:
        ``(tps, total_tp, dif_stats)``: per-millisecond throughput list, its sum,
        and a dict with ``"totals"`` / ``"changes"`` ``SeriesStats`` fed by
        ``pairs_dif``.
    """
    if kwargs.get("max_degree", 1) == 1:
        return compute_throughput(compute_epoch=compute_epoch, agg_interval=agg_interval, agg_epoch_delay=agg_epoch_delay, top=top,
                                  window=window, iterations=iterations, win_delay=win_delay, threshold=threshold,
                                  n_milis=n_milis, n_tors=n_tors, output_dir=output_dir, **kwargs)
    print dict(compute_epoch=compute_epoch, agg_interval=agg_interval, agg_epoch_delay=agg_epoch_delay)
    per_mili_pattern = path.join(output_dir, "matrix_mili_%d")
    per_mili_matrix_list = MatrixList(per_mili_pattern)
    if top:
        per_mili_match_pattern = path.join(output_dir, "mwm_mili_%d_top_" + str(top)+("_deg_%d" % kwargs.get("max_degree", 1)))
    else:
        per_mili_match_pattern = path.join(output_dir, "mwm_mili_%d"+("_deg_%d" % kwargs.get("max_degree", 1)))
    per_mili_match_list = DictList(per_mili_match_pattern)
    
    if top:
        per_interval_match_pattern = path.join(output_dir, "mwm_agg_%d_%d-%d_top_" + str(top)+("_deg_%d" % kwargs.get("max_degree", 1)))
    else:
        per_interval_match_pattern = path.join(output_dir, "mwm_agg_%d_%d-%d"+("_deg_%d" % kwargs.get("max_degree", 1)))
    per_interval_match_list = DictList(per_interval_match_pattern)
    
    def get_centralized_matches(t):
        """Return ``{"links": ..., "weights": ...}`` for the epoch containing millisecond *t*.

        ``links`` maps a ToR to the set of its centralized peers and
        ``weights[x][y]`` is the bidirectional traffic of link (x, y) over the MWM
        interval ``[start, end)`` divided by ``agg_interval``.  Before any history
        exists the result has no links and ``weights`` is None.
        """
        end = t - t % compute_epoch - (agg_epoch_delay - 1) * compute_epoch
        # t - t % compute_epoch                 = beginning of the current decision epoch
        # (agg_epoch_delay - 1) * compute_epoch = how far back (in epochs) the MWM interval ends
        # so `end` is the end of the interval the MWM is computed on.
        if end > n_milis:
            # The interval would run past the experiment: step one epoch back.
            end -= compute_epoch
        start = end - agg_interval  # where the MWM interval starts
        print "start", start, "end", end
        if start < 0:
            #mwm = []
            # traffic = make_zero_matrix(n_tors, n_tors)
            # No history yet: no centralized links at all.
            return dict(links=defaultdict(set), weights=None)
        elif agg_interval == 1:  # the MWM is based on a single millisecond
            mwm = per_mili_match_list[start]  # MWM of that millisecond
            traffic = list(per_mili_matrix_list[start])  # traffic matrix of that millisecond
        else:
            mwm = per_interval_match_list[(agg_interval, start, end - 1)]
            traffic = make_zero_matrix(n_tors, n_tors)
            for tt in range(start, end):  # accumulate the interval's traffic for the threshold test
                matrix = per_mili_matrix_list[tt]
                add_with_matrix(traffic, matrix)
        # Plain dict on purpose: only ToRs that really have a centralized link become keys.
        links = {} #defaultdict(set)
        mwm_weight = defaultdict(lambda: defaultdict(int))
        for m in mwm:
            x, y = m
            if x not in links:
                links[x] = set()
            if y not in links:
                links[y] = set()
            links[x].add(y)
            links[y].add(x)
            # NOTE(review): under Python 2 this `/` truncates if the traffic values are ints - confirm intended.
            weight = (traffic[x][y] + traffic[y][x]) / agg_interval
            mwm_weight[x][y] = weight
            mwm_weight[y][x] = weight
        res = dict(links=links, weights=mwm_weight)
        # res["last_traffic"] = traffic
        return res
    
    tps = []
    total_tp = 0
    
    trafic_window = TrafficWindow(window,
                                  win_delay)  # traffic history the distributed decision looks at
    distributed_pairing = defaultdict(set)  # ToR -> set of current peers
    dif_stats = dict(totals=SeriesStats(), changes=SeriesStats())
    old_pairs = pairing_to_pairs_ex(distributed_pairing)
    for z in range(
            n_milis / compute_epoch):  # one pass per centralized computation (number of whole epochs in the experiment)
        start = z * compute_epoch  # first millisecond of this epoch
        centralized_matches = get_centralized_matches(start)
        # Every epoch starts from a copy of the centralized links (the peer sets are copied
        # so that distributed changes do not alter the centralized data).
        distributed_pairing = defaultdict(set)
        distributed_pairing.update({x: set(centralized_matches["links"][x]) for x in centralized_matches["links"]})
        for t in range(z * compute_epoch,
                       (z + 1) * compute_epoch):  # between two MWM updates the distributed computation takes place
            print "\r", t, "/", n_milis,
            matrix = list(per_mili_matrix_list[t])
            if t != 0:  # skipped only for the very first millisecond, when the window is still empty
                update_distributed_pairing_ex_bounded(
                    last_pairing=distributed_pairing,
                    window_weight=trafic_window.get_sum(),
                    centralized_pairing=centralized_matches["links"],
                    centralized_weight=centralized_matches["weights"],
                    threshold=threshold,
                    iterations=iterations,
                    n_tors=n_tors,
                    **kwargs
                )
            # The bare string below is disabled debugging code kept by the author; it has no effect.
            """if distributed_pairing != centralized_matches["links"]:
                dif = {}
                for n in range(n_tors):
                    if distributed_pairing.get(n,set()) != centralized_matches["links"].get(n,set()):
                        in_dist = distributed_pairing[n] - centralized_matches["links"][n]
                        in_cent = centralized_matches["links"][n] - distributed_pairing[n]
                        dif[n] = (in_dist, in_cent)
                print "deviation from centralized"
                """
            new_pairs = pairing_to_pairs_ex(distributed_pairing)
            pairs_dif(new_pairs, old_pairs, dif_stats)
            old_pairs = new_pairs
            # Traffic over every linked direction (each link appears under both of its ends).
            tp = sum([matrix[u][v] for u in distributed_pairing for v in distributed_pairing[u]])
            tps.append(tp)
            total_tp += tp
            trafic_window.add(matrix)  # the window learns about this millisecond only after it was served
    return tps, total_tp, dif_stats


def write_results(tps, total_tp, dif_stats,
                  compute_epoch=1, agg_interval=1, agg_epoch_delay=0, top=None,
                  window=1, threshold=1, iterations=3, win_delay=0,
                  n_milis=5000, output_dir=".",
                  max_degree=1, total_load=0, run_id="", flow_avg=0, flow_var=0, **kwargs):
    """Store the outcome of one run under ``output_dir``.

    Two files are written:
      * ``res_<n_milis>_<conf_name>.json`` - ``{"total_tp": ..., "tps": ...}``,
        overwritten on every call;
      * ``res_<n_milis>.csv`` - one appended row per call; the header row is
        written first when the file does not exist yet.

    ``conf_name`` encodes agg_epoch_delay, compute_epoch, agg_interval, window,
    threshold, iterations and win_delay.  ``dif_stats["totals"]`` and
    ``dif_stats["changes"]`` must offer ``get_avg()`` and ``get_var()``;
    ``tps`` must be a list.  Extra keyword arguments are ignored.
    """
    conf_name = "dist_delay%s_epoch%s_agg%s_win%s_t%s_i%s_wd%s" % (
        agg_epoch_delay, compute_epoch, agg_interval, window, threshold, iterations, win_delay)

    # Detailed result of this configuration.
    json_path = path.join(output_dir, "res_%s_%s.json" % (n_milis, conf_name))
    with open(json_path, "w") as json_file:
        json.dump({"total_tp": total_tp, "tps": tps}, json_file)

    # Summary table shared by all configurations with the same n_milis.
    csv_path = path.join(output_dir, "res_%s.csv" % (n_milis,))
    if not path.isfile(csv_path):
        header = ["mode", "total_tp", "n_milis", "compute_epoch", "agg_interval", "agg_epoch_delay", "top",
                  "max_degree", "links_avg", "links_var", "change_avg", "change_var", "total_load", "run_id",
                  "flow_avg", "flow_var", "tps"]
        with open(csv_path, "w") as csv_file:
            csv_file.write(",".join(header) + "\n")
    with open(csv_path, "a+") as csv_file:
        totals = dif_stats["totals"]
        changes = dif_stats["changes"]
        fields = [conf_name, total_tp, n_milis, compute_epoch, agg_interval, agg_epoch_delay, top, max_degree,
                  totals.get_avg(), totals.get_var(),
                  changes.get_avg(), changes.get_var(),
                  total_load, run_id, flow_avg, flow_var] + tps  # per-millisecond values trail the fixed columns
        csv_file.write(",".join([str(field) for field in fields]) + "\n")
        
        
def main():
    """Command-line entry point: load the configuration, run one simulation, store its results.

    The JSON configuration given by ``--conf`` is loaded, the random generator
    is seeded from its ``"seed"`` entry (1 if absent), and the command-line
    values of compute_epoch, agg_interval, agg_epoch_delay, threshold, window
    and iterations overwrite the corresponding configuration entries.  A
    threshold of -1 selects the distributed-only simulation; any other value
    selects the hybrid one.
    """
    parser = argparse.ArgumentParser(
        description="""Compute throughput for each mili [t,t+1) using MWM of interval [start, end) where
    end = t - t%compute_epoch - (agg_epoch_delay-1)*compute_epoch, and
    start = end - agg_interval.
    """,
        epilog="""
    """)
    # NOTE(review): --output_dir is parsed but never read below; the simulations take
    # output_dir from the configuration file (or their own default).
    parser.add_argument('--output_dir', default=".", type=str,
                        help='output directory file (default: conf.json)')
    parser.add_argument('--conf', default="conf.json", type=open,
                        help='configuration file (default: conf.json)')
    # centralized
    parser.add_argument('--compute_epoch', default=1, type=int,
                        help='each compute interval length (default: 1 = mili)') # y
    parser.add_argument('--agg_interval', default=1, type=int,
                        help='each aggregation interval length according to which the MWM is computed (default: 1 = mili)')#z
    parser.add_argument('--agg_epoch_delay', default=0, type=int,
                        help='number of compute epoch delays before aggregation (default: 0 = offline)')
    
    parser.add_argument('--threshold', default=1, type=int,
                        help='the rate precentage from which the distributed optimization is considered, -1 for distributed only (default: 0 = centrelized)')
    parser.add_argument('--window', default=1, type=int,
                help='the length of a sliding window that the algorithm looks into when deciding whether to keep or break an edge in the matching  (default: 1 = mili)')
    parser.add_argument('--iterations', default=3, type=int,
                        help='the number of distributed iteration in per mili decission (default: 3)')
    args = parser.parse_args()
    # argparse opened the configuration file (type=open); close it once it is parsed.
    try:
        conf = json.load(args.conf)
    finally:
        args.conf.close()
    random.seed(conf.get("seed", 1))
    conf["compute_epoch"] = args.compute_epoch
    conf["agg_interval"] = args.agg_interval
    conf["agg_epoch_delay"] = args.agg_epoch_delay
    conf["threshold"] = args.threshold
    conf["window"] = args.window
    conf["iterations"] = args.iterations
    # Both simulations return (tps, total_tp, dif_stats), and write_results needs
    # dif_stats as its third positional argument.
    if conf["threshold"] == -1:
        tps, total_tp, dif_stats = compute_dist_only_throughput_ex(**conf)
    else:
        tps, total_tp, dif_stats = compute_throughput_ex(**conf)
    write_results(tps, total_tp, dif_stats, **conf)


if __name__ == "__main__":
    main()

# back to top