#!/usr/bin/env python
'''
ATPG (Automatic Test Packet Generation) for the Internet2 backbone.

Source: https://github.com/eastzone/atpg

@author: James Hongyi Zeng
'''
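# Example invocation (illustrative only -- the database filename below is made
# up; the Internet2 .tf files and the work/ output directory must exist):
#
#   python atpg_internet2.py -p 50 -f internet2_sample.sqlite -e
#
# -p takes the percentage of test terminals to sample, -f the name of the
# SQLite database created under work/, and -e restricts the source ports to
# edge ports.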
from utils.load_internet2_backbone import *
from headerspace.applications import *
from headerspace.hs import *
from multiprocessing import Pool, cpu_count
from config_parser.juniper_parser import juniperRouter
import random, time, sqlite3, os, json, socket, struct
from argparse import ArgumentParser

ntf_global = None
ttf_global = None
src_port_ids_global = set()
dst_port_ids_global = set()

DATABASE_FILE = "work/internet2.sqlite"
TABLE_NETWORK_RULES = "network_rules"
TABLE_TOPOLOGY_RULES = "topology_rules"
TABLE_RESULT_RULES = "result_rules"
TABLE_TEST_PACKETS = "test_packets"
TABLE_TEST_PACKETS_GLOBALLY_COMPRESSED = "test_packets_globally_compressed"
TABLE_TEST_PACKETS_LOCALLY_COMPRESSED = "test_packets_locally_compressed"

TABLE_SCRATCHPAD = "scratch_pad"
CPU_COUNT = cpu_count()

port_reverse_map_global = {}
port_map_global = {}


def parse_non_wc_field(field, right_wc):
    '''
    Decode a non-fully-wildcarded header-space field into its concrete values.
    right_wc says whether the field may be right-masked (wildcarded in its
    least-significant bits), as IP prefix fields are: pass True for IP fields
    and False otherwise. Returns [values, found_right_wc].
    '''
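    # Illustrative example (not part of the original source): every header bit
    # is encoded in two bits of the byte array -- 0x01 for 0, 0x02 for 1,
    # 0x03 for wildcard -- so one byte holds four header bits, least
    # significant pair first. For instance,
    #     parse_non_wc_field(bytearray([0xff, 0x66]), True)
    # returns [[80], 5]: the low four bits are right-wildcarded (0xff), the
    # upper four decode to 0b0101, giving the single value 0b01010000 = 80,
    # and found_right_wc = 5 (index of the first non-wildcarded bit, plus one).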
    values = []
    if right_wc:
        found_right_wc = -1
    else:
        found_right_wc = 0
    for i in range (len(field)):
        for j in range (4):
            next_bit = (field[i] >> (2 * j)) & 0x03
            if right_wc and found_right_wc == -1 and next_bit != 0x03:
                # detect when we have scanned all right wildcarded bits
                found_right_wc = j + i * 4 + 1
                values.append(0)
            new_values = []
            for value in values:
                if (next_bit == 0x02 or next_bit == 0x03) and found_right_wc != -1:
                    new_values.append(value + 2 ** (j + i * 4))
                if (next_bit == 0x01 or next_bit == 0x03) and found_right_wc != -1:
                    new_values.append(value)
            values = new_values
    
    return [values, found_right_wc]

def parse_normal_field(field, right_wc):
    '''
    Decode a header-space field into its concrete values, enumerating any
    wildcarded bits. right_wc says whether the field may be right-masked
    (wildcarded in its least-significant bits): pass True for IP fields and
    False otherwise. Returns [values, found_right_wc].
    '''
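    # Illustrative example (not part of the original source): with
    # right_wc=False,
    #     parse_normal_field(bytearray([0x99]), False)
    # returns [[10], 0] -- the byte 0x99 holds the bit pairs 01,10,01,10
    # (least-significant pair first), i.e. the exact value 0b1010 = 10.
    # A wildcarded pair (0x03) doubles the candidate list:
    #     parse_normal_field(bytearray([0x9d]), False)
    # returns [[10, 8], 0], enumerating both settings of the wildcarded bit 1.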
    values = [0]
    if right_wc:
        found_right_wc = -1
    else:
        found_right_wc = 0
    for i in range (len(field)):
        for j in range (4):
            next_bit = (field[i] >> (2 * j)) & 0x03
            if right_wc and found_right_wc == -1 and next_bit != 0x03:
                # detect when we have scanned all right wildcarded bits
                found_right_wc = j + i * 4 + 1
                values.append(0)
            new_values = []
            for value in values:
                if (next_bit == 0x02 or next_bit == 0x03) and found_right_wc != -1:
                    new_values.append(value + 2 ** (j + i * 4))
                if (next_bit == 0x01 or next_bit == 0x03) and found_right_wc != -1:
                    new_values.append(value)
            values = new_values
    
    return [values, found_right_wc]

def parse_hs(hs_format, hs):
    '''
    Convert one header-space byte array (hs) into a dict of OpenFlow-style
    match fields, keyed "<field>_match" and "<field>_wc", using the field
    positions and lengths recorded in hs_format.
    '''
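    # Hypothetical output shape (field values made up for illustration):
    #   {"ip_dst_match": ["10.0.0.0"], "ip_dst_wc": 9,
    #    "ip_proto_match": [6], "ip_proto_wc": 0, ...}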
    
    match = hs
    
    fields = ["mac_src", "mac_dst", "vlan", "ip_src", "ip_dst", "ip_proto", "transport_src", "transport_dst"]
    openflow_entry = {}
    for field in fields:
        if "%s_pos" % field not in hs_format.keys():
            continue
        
        position = hs_format["%s_pos" % field]
        len = hs_format["%s_len"%field]
        wildcarded = True
        field_match = bytearray()
        for i in range(2 * len):
            field_match.append(match[position * 2 + i])
            if match[position * 2 + i] != 0xff:
                wildcarded = False

        if wildcarded:
            if field == "ip_src" or field == "ip_dst":
                openflow_entry["%s_wc" % field] = 32
            else:
                openflow_entry["%s_wc" % field] = 1
            openflow_entry["%s_match" % field] = [0]
        else:
            if field == "ip_src" or field == "ip_dst":
                parsed = parse_non_wc_field(field_match, True)
                if parsed[0] != []:
                    parsed[0][0] = socket.inet_ntoa(struct.pack('!L',parsed[0][0]))
            else:
                parsed = parse_normal_field(field_match, False)
            openflow_entry["%s_wc" % field] = parsed[1]
            openflow_entry["%s_match" % field] = parsed[0]
    
    return openflow_entry

def find_reachability_test(NTF, TTF, in_port, out_ports, input_pkt):
    '''
    Propagate input_pkt from in_port through the network transfer function
    (NTF) and the topology transfer function (TTF), breadth first, avoiding
    already-visited ports, and return the propagation nodes whose port is in
    out_ports.
    '''
    paths = []
    propagation = []
 
    p_node = {}
    p_node["hdr"] = input_pkt
    p_node["port"] = in_port
    p_node["visits"] = []
    #p_node["hs_history"] = []
    propagation.append(p_node)
    #loop_count = 0
    while len(propagation) > 0:
        #get the next node in propagation graph and apply it to NTF and TTF
        #print "Propagation has length: %d"%len(propagation)
        tmp_propagate = []
        for p_node in propagation:
            next_hp = NTF.T(p_node["hdr"], p_node["port"])
            for (next_h, next_ps) in next_hp:            
                for next_p in next_ps:
                    new_p_node = {}
                    new_p_node["hdr"] = next_h
                    new_p_node["port"] = next_p
                    new_p_node["visits"] = list(p_node["visits"])
                    new_p_node["visits"].append(p_node["port"])
                    #new_p_node["hs_history"] = list(p_node["hs_history"])
                  
                    # Reached an edge port
                    if next_p in out_ports:
                        paths.append(new_p_node)
                        
                    linked = TTF.T(next_h, next_p)
                    
                    for (linked_h, linked_ports) in linked:
                        for linked_p in linked_ports:
                            new_p_node = {}
                            new_p_node["hdr"] = linked_h
                            new_p_node["port"] = linked_p
                            new_p_node["visits"] = list(p_node["visits"])
                            new_p_node["visits"].append(p_node["port"])
                            #new_p_node["hs_history"] = list(p_node["hs_history"])
                            #new_p_node["hs_history"].append(p_node["hdr"])
                            if linked_p not in new_p_node["visits"]:
                                tmp_propagate.append(new_p_node)
                                
        propagation = tmp_propagate
                
    return paths

def print_paths_to_database(paths, reverse_map, table_name):
    # The sqlite3 connection below uses a 6000-second busy timeout
    
    insert_string = "INSERT INTO %s VALUES (?, ?, ?, ?, ?, ?, ?)" % table_name
    
    queries = []
    for p_node in paths:
        path_string = ""
        for port in p_node["visits"]:
            path_string += ("%d " % port)
        path_string += ("%d " % p_node["port"])
        port_count = len(p_node["visits"]) + 1
        
        rl_id = ""
        for (n, r, s) in p_node["hdr"].applied_rule_ids:
            rl_id += (r + " ")
        rule_count = len(p_node["hdr"].applied_rule_ids)
        
        input_port = p_node["visits"][0]
        output_port = p_node["port"]
        output_hs = p_node["hdr"].copy()
        applied_rule_ids = list(output_hs.applied_rule_ids)
        input_hs = trace_hs_back(applied_rule_ids, output_hs, output_port)[0]
        header_string = json.dumps(parse_hs(juniperRouter(1).hs_format, input_hs.hs_list[0]))
        
        #header_string = byte_array_to_pretty_hs_string(input_hs.hs_list[0])
        queries.append((header_string, input_port, output_port, path_string, port_count, rl_id, rule_count))
    
    conn = sqlite3.connect(DATABASE_FILE, 6000)
    for query in queries:    
        conn.execute(insert_string, query)
        
    conn.commit()
    conn.close()

def path_compress(paths):
    ''' Compress the path set with a greedy Min-Set-Cover heuristic:
    after shuffling, keep a path only if it exercises at least one rule not
    already covered by the paths kept so far.
    '''
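    # Illustrative example (hypothetical rule ids): for paths exercising the
    # rule sets {r1, r2}, {r2} and {r3}, a greedy pass in that order keeps
    # the {r1, r2} path, skips the {r2} path (no new rules) and keeps the
    # {r3} path -- every rule stays covered with one fewer test packet.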
    result_paths = []
    exercised_rules = set()
    random.shuffle(paths)
    for p_node in paths:
        new_rule = False
        for (n, r, s) in p_node["hdr"].applied_rule_ids:
            if r not in exercised_rules:
                new_rule = True
                break
        if new_rule:
            result_paths.append(p_node)
            for (n, r, s) in p_node["hdr"].applied_rule_ids:
                exercised_rules.add(r)
    
    return result_paths

def rule_lists_compress(rule_lists):
    '''
    Global compression pass: repeatedly pick a random rule list and keep it
    only if it covers at least one rule not yet exercised; stop once every
    rule reachable from the input lists is covered. The surviving lists are
    appended to the scratch-pad table.
    '''
    st = time.time()
    
    rule_ids_set = set()
    for rule_list in rule_lists:
        rule_ids_set |= set(rule_list)
    
    #print "Reachable Rules: %d" % len(rule_ids_set)
    start_packets = len(rule_lists)
    result_rule_lists = []
    while(len(rule_ids_set) > 0):
        lucky_index = random.randint(0, len(rule_lists)-1)
        rule_list = rule_lists[lucky_index]
        for r in rule_list:
            if r in rule_ids_set:
                result_rule_lists.append(rule_list)
        
                # Rules that have been hit already
                rule_ids_set -= set(rule_list)
                del rule_lists[lucky_index]                
                break
    
    end_packets = len(result_rule_lists)
    
    en = time.time()
    
    print "Global Compression: Start=%d, End=%d, Ratio=%f, Time=%f" % (start_packets, end_packets, float(end_packets)/start_packets, en-st)
    print_rule_lists_to_database(result_rule_lists, TABLE_SCRATCHPAD)

def find_test_packets(src_port_id):

    # Generate All-X packet
    all_x = byte_array_get_all_x(ntf_global.length)
    test_pkt = headerspace(ntf_global.length)
    test_pkt.add_hs(all_x)
       
    st = time.time()
    paths = find_reachability_test(ntf_global, ttf_global, src_port_id, dst_port_ids_global, test_pkt)
    en = time.time()
    
    print_paths_to_database(paths, port_reverse_map_global, TABLE_TEST_PACKETS)
    result_string = "Port:%d, Path No:%d, Time: %fs" % (src_port_id, len(paths), en - st)    
    print result_string

    # Compress
    st = time.time()
    paths = path_compress(paths)
    en = time.time()

    result_string = "Port:%d, Compressed Path No:%d, Time: %fs" % (src_port_id, len(paths), en - st)    
    print result_string
    
    print_paths_to_database(paths, port_reverse_map_global, TABLE_TEST_PACKETS_LOCALLY_COMPRESSED)

    return len(paths)

def chunks(l, n):
    """ Yield successive n chunks from l.
    """
    sub_list_length = len(l) / n        
    if sub_list_length == 0:
        sub_list_length = len(l)    
    return [l[i:i+sub_list_length] for i in range(0, len(l), sub_list_length)]

def merge_chunks(chunks):
    result = sum(chunks, [])
    return result

def print_rule_lists_to_database(result_rule_lists, table_name):
    conn = sqlite3.connect(DATABASE_FILE, 6000)
    query = "INSERT INTO %s VALUES (?, ?)" % table_name
   
    for rule_list in result_rule_lists:
        conn.execute(query, (" ".join(rule_list), len(rule_list)))
     
    conn.commit()    
    conn.close()
    
def read_rule_lists_from_database(table_name):
    result_rule_lists = []
    conn = sqlite3.connect(DATABASE_FILE, 6000)
    query = "SELECT rules FROM %s"  % TABLE_SCRATCHPAD
    rows = conn.execute(query)
    
    for row in rows:
        result_rule_lists.append(row[0].split())
    conn.close()
    return result_rule_lists

def main():  
    global src_port_ids_global
    global dst_port_ids_global
    global port_map_global
    global port_reverse_map_global
    global ntf_global
    global ttf_global
    global DATABASE_FILE
    
    parser = ArgumentParser(description="Generate Test Packets for Internet2")
    parser.add_argument("-p", dest="percentage", type=int,
                      default="100",
                      help="Percentage of test terminals")
    parser.add_argument("-f", dest="filename",
                      default="internet2.sqlite",
                      help="Filename of the database")
    parser.add_argument("-e", action="store_true",
                      default=False,
                      help="Edge port only")
    args = parser.parse_args()
    
    DATABASE_FILE = "work/%s" % args.filename
     
    cs = juniperRouter(1)
    output_port_addition = cs.PORT_TYPE_MULTIPLIER * cs.OUTPUT_PORT_TYPE_CONST
     
    # Load .tf files
    ntf_global = load_internet2_backbone_ntf()
    ttf_global = load_internet2_backbone_ttf()
    (port_map_global, port_reverse_map_global) = load_internet2_backbone_port_to_id_map()
    
    # Initialize the database
    if os.access(DATABASE_FILE, os.F_OK):
        os.remove(DATABASE_FILE)
    
    conn = sqlite3.connect(DATABASE_FILE)
    conn.execute('CREATE TABLE %s (rule TEXT, input_port TEXT, output_port TEXT, action TEXT, file TEXT, line TEXT)' % TABLE_NETWORK_RULES)
    conn.execute('CREATE TABLE %s (rule TEXT, input_port TEXT, output_port TEXT)' % TABLE_TOPOLOGY_RULES)
    conn.execute('CREATE TABLE %s (header TEXT, input_port INTEGER, output_port INTEGER, ports TEXT, no_of_ports INTEGER, rules TEXT, no_of_rules INTEGER)' % TABLE_TEST_PACKETS)
    conn.execute('CREATE TABLE %s (header TEXT, input_port INTEGER, output_port INTEGER, ports TEXT, no_of_ports INTEGER, rules TEXT, no_of_rules INTEGER)' % TABLE_TEST_PACKETS_LOCALLY_COMPRESSED)
    conn.execute('CREATE TABLE %s (rules TEXT, no_of_rules INTEGER)' % TABLE_TEST_PACKETS_GLOBALLY_COMPRESSED)
    conn.execute('CREATE TABLE %s (rule TEXT)' % TABLE_RESULT_RULES)

    rule_count = 0
    for tf in ntf_global.tf_list:
        rule_count += len(tf.rules)
        for rule in tf.rules:
            query = "INSERT INTO %s VALUES (?, ?, ?, ?, ?, ?)" % TABLE_NETWORK_RULES
            conn.execute(query, (rule['id'],' '.join(map(str, rule['in_ports'])), ' '.join(map(str, rule['out_ports'])), rule['action'], rule["file"], ' '.join(map(str, rule["line"]))))
    print "Total Rules: %d" % rule_count
    conn.commit()
    
    rule_count = len(ttf_global.rules) 
    for rule in ttf_global.rules:
        query = "INSERT INTO %s VALUES (?, ?, ?)" % TABLE_TOPOLOGY_RULES 
        conn.execute(query, (rule['id'],' '.join(map(str, rule['in_ports'])), ' '.join(map(str, rule['out_ports']))))  
    print "Total Links: %d" % rule_count
   
    # Generate all ports
    for rtr in port_map_global.keys():
        src_port_ids_global |= set(port_map_global[rtr].values())
    
    
    total_length = len(src_port_ids_global)
    if args.e:
        # Edge ports only: drop every port that appears as the far end of a
        # topology (link) rule, leaving only ports without an attached link.
        for rule in ttf_global.rules:
            if rule['out_ports'][0] in src_port_ids_global:
                src_port_ids_global.remove(rule['out_ports'][0])
    
    new_length = len(src_port_ids_global)* args.percentage / 100
    src_port_ids_global = random.sample(src_port_ids_global, new_length)
    print "Total Length: %d" % total_length
    print "New Length: %d" % new_length
    
    for port in src_port_ids_global:
        dst_port_ids_global.add(port + output_port_addition)
    
    #src_port_ids_global = [300013]
    #dst_port_ids_global = [320010]
    
    conn.commit()
    conn.close()
    
    # Run reachability
    start_time = time.time()
    
    pool = Pool()
    result = pool.map_async(find_test_packets, src_port_ids_global)

    # Close
    pool.close()
    pool.join()
    
    end_time = time.time()
    
    test_packet_count = result.get()
    total_paths = sum(test_packet_count)    
    print "========== Before Compression ========="
    print "Total Paths = %d" % total_paths
    print "Average packets per port = %f" % (float(total_paths) / len(src_port_ids_global))
    print "Total Time = %fs" % (end_time - start_time)
    
    # Global compression: iteratively re-run the greedy Min-Set-Cover over all
    # per-port rule lists until the reduction per round falls below 1%.
    start_time = time.time()
       
    conn = sqlite3.connect(DATABASE_FILE, 6000)    
    result_rule_lists = []
    query = "SELECT rules FROM %s"  % TABLE_TEST_PACKETS_LOCALLY_COMPRESSED
    rows = conn.execute(query)

    for row in rows:
        result_rule_lists.append(row[0].split())
    conn.close()
  
    chunk_size = 80000
    while(True):
        print "Start a new round!"
        conn = sqlite3.connect(DATABASE_FILE, 6000)
        conn.execute('DROP TABLE IF EXISTS %s' % TABLE_SCRATCHPAD)
        conn.execute('CREATE TABLE %s (rules TEXT, no_of_rules INTEGER)' % TABLE_SCRATCHPAD)
        conn.commit()    
        conn.close()
        
        start_len = len(result_rule_lists)
        print start_len
        
        pool = Pool()        
        no_of_chunks = len(result_rule_lists) / chunk_size + 1      
        rule_list_chunks = chunks(result_rule_lists, no_of_chunks)            
        result = pool.map_async(rule_lists_compress, rule_list_chunks)

        # Close
        pool.close()
        pool.join()
        result.get()
        
        print "End of this round."
        
        result_rule_lists = read_rule_lists_from_database(TABLE_SCRATCHPAD)
        
        end_len = len(result_rule_lists)
        if(float(end_len) / float(start_len) > 0.99):
            break

    end_time = time.time()
    
    query = "INSERT INTO %s VALUES (?, ?)" % TABLE_TEST_PACKETS_GLOBALLY_COMPRESSED
    query2 = "INSERT INTO %s VALUES (?)" % TABLE_RESULT_RULES
    
    total_paths = len(result_rule_lists)
    total_length = 0
    
    conn = sqlite3.connect(DATABASE_FILE, 6000)
    conn.execute('DROP TABLE IF EXISTS %s' % TABLE_TEST_PACKETS_GLOBALLY_COMPRESSED)
    conn.execute('CREATE TABLE %s (rules TEXT, no_of_rules INTEGER)' % TABLE_TEST_PACKETS_GLOBALLY_COMPRESSED)

    for rule_list in result_rule_lists:
        total_length += len(rule_list)
        conn.execute(query, (" ".join(rule_list), len(rule_list)))
        for rule in rule_list:
            conn.execute(query2, (rule,))
     
    conn.commit()    
    conn.close()
    
    print "========== After Compression ========="
    print "Total Paths = %d" % total_paths
    print "Average packets per port = %f" % (float(total_paths) / len(src_port_ids_global))
    print "Average length of rule list = %f" % (float(total_length) / total_paths)
    print "Total Time = %fs" % (end_time - start_time)
    
if __name__ == "__main__":
    main()