https://github.com/epiqc/ScaffCC
Tip revision: 9d2cca71cf54ddfebda26e247d82ae7b71d9e03c authored by Pranav Gokhale on 30 June 2018, 18:56:21 UTC
Fix OpenQASM output formatting of Rx and Ry
Fix OpenQASM output formatting of Rx and Ry
Tip revision: 9d2cca7
arrange.py
# Pranav Gokhale
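# Usage: python arrange.py input_trace_file.tr P YX DROP [width height]
# P, YX and DROP are only used to name the temporary METIS files
# ('<input>.p.P.yx.YX.drop.DROP.delete_*'). width and height, if given, are the
# 5th and 6th arguments; if they are not specified, each module defaults to the
# side length of the smallest square big enough to contain all of its nodes.
# Input file: .tr file
# Output:
# - <input>.graph: the two-qubit interaction graph of each module
# - <input>.opt.tr: the trace with qubit indices renumbered according to the
#   placement of nodes on the rectangle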
import os, sys, math, re
from os.path import basename
from os.path import splitext
# Node count of the module most recently read by parser(), which rebinds this
# global (via `global module_num_nodes`) for every module it parses.
module_num_nodes = 0
def _remove_temp_files(benchname):
    """Delete the scratch files '<benchname>.delete_*' (what `rm -f` did before).

    The files are matched in Python rather than via a shell command, so a
    benchname containing spaces or shell metacharacters can never expand into
    other paths and delete unrelated files.
    """
    directory, prefix = os.path.split(benchname + '.delete_')
    directory = directory or '.'
    for name in os.listdir(directory):
        path = os.path.join(directory, name)
        if name.startswith(prefix) and os.path.isfile(path):
            os.remove(path)


def main():
    """Place every module's qubits on a grid and write the remapped .opt.tr file.

    Converts the trace named by sys.argv[1] into a .graph file, places each
    module's nodes with get_locations(), and lets replace() renumber the qubits
    by their row-major position among the occupied grid cells.
    sys.argv[2..4] are only used to name the temporary METIS files, which are
    removed after each module and again at the end.
    """
    if len(sys.argv) < 5:
        sys.exit('Usage: python arrange.py input_trace_file.tr P YX DROP [width height]')
    trace_to_graph(sys.argv[1])
    mapping_dic = {}
    nodes_dic, tl_dic, br_dic = parser()  # tl = top left coordinate, br = bottom right coordinate
    benchname = (splitext(basename(sys.argv[1]))[0] + '.p.' + sys.argv[2] +
                 '.yx.' + sys.argv[3] + '.drop.' + sys.argv[4])
    for module_name, nodes in nodes_dic.items():
        br = br_dic[module_name]
        locations = get_locations(nodes, tl=tl_dic[module_name], br=br)
        _remove_temp_files(benchname)  # partition() creates many files prefixed with delete_
        node_map = {point: node for node, point in locations.items()}
        # One text row per grid row: "<index> " for an occupied cell, " " for an empty one.
        rows = []
        for y in range(br.y + 1):
            cells = []
            for x in range(br.x + 1):
                node = node_map.get(Point(x, y))
                cells.append(str(node.index) + ' ' if node is not None else ' ')
            rows.append(''.join(cells) + '\n')
        mapping_dic[module_name] = ''.join(rows)
    replace(mapping_dic, sys.argv[1])
    _remove_temp_files(benchname)
def trace_to_graph(infile):
    """Convert a .tr trace file into the .graph file read by parser().

    Output (written to ``infile`` with '.tr' replaced by '.graph'):
      - "module: <name>" for every "module: <name>" line,
      - "<n>" for every "num_nodes: <n>" line,
      - "<src> <dst>" for every two-qubit gate line "ID: ... SRC: <src> DST: <dst> ...".
    Header lines and single-qubit gates (no "DST:") are skipped.

    Raises:
        ValueError: a gate line appears before any "num_nodes:" line.
    """
    assert '.tr' in infile
    outfile = infile.replace('.tr', '.graph')
    num_nodes = None
    # `with` guarantees the .graph file is flushed and closed before parser() reads it.
    with open(infile) as fin, open(outfile, 'w') as fout:
        for line in fin:
            if line.startswith('module:'):
                module_name = line.split(' ')[1]
                # Always terminate the line ourselves instead of relying on the
                # token carrying the input's trailing newline.
                fout.write('module: %s\n' % module_name.rstrip('\n'))
            elif line.startswith('num_nodes:'):
                num_nodes = int(line.split(' ')[1])
                fout.write('%s\n' % num_nodes)
            elif line.startswith('ID:') and 'DST:' in line:  # two-qubit gates only
                tokens = line.split(' ')
                assert tokens[4].startswith('SRC:')
                assert tokens[6].startswith('DST:')
                src = int(tokens[5])
                dst = int(tokens[7])
                if num_nodes is None:
                    raise ValueError('gate line before any num_nodes line: %r' % line)
                assert src < num_nodes
                assert dst < num_nodes
                fout.write('%s %s\n' % (src, dst))
def parser():
    """Parse the .graph file produced by trace_to_graph() into per-module nodes.

    The file read is sys.argv[1] with '.tr' replaced by '.graph'. It holds one
    section per module: a "module: <name>" line, a line with the module's node
    count, then one "src dst" line per two-qubit interaction. Example section
    body:
        5
        0 1
        1 2
        2 3
        3 4
    creates
        nodes = [
            Node(0, [0, 1, 0, 0, 0])
            Node(1, [1, 0, 1, 0, 0])
            Node(2, [0, 1, 0, 1, 0])
            Node(3, [0, 0, 1, 0, 1])
            Node(4, [0, 0, 0, 1, 0])
        ]

    Each module is placed on a width x height rectangle: int(sys.argv[5]) x
    int(sys.argv[6]) when both are given, otherwise the smallest square with at
    least as many cells as the module has nodes.

    Returns:
        (nodes_dic, tl, br): dicts keyed by module name (the raw token after
        "module:", trailing newline included). nodes_dic maps to the module's
        list of Node objects, tl to the rectangle's top-left Point (always
        (0, 0)) and br to its bottom-right Point.

    Side effect: rebinds the module-level ``module_num_nodes`` to each parsed
    module's node count.
    """
    global module_num_nodes
    nodes_dic = {}
    tl = {}
    br = {}

    def _square_size(num_nodes):
        """Side of the smallest integer square with area >= num_nodes, i.e.
        (size - 1) * (size - 1) < num_nodes <= size * size."""
        return int(math.ceil(math.sqrt(num_nodes)))

    def _add_edge(nodes, src, dst):
        """Add one unit of weight to the undirected edge src <-> dst."""
        nodes[src].weights[dst] += 1
        nodes[dst].weights[src] += 1

    def _save_module(name, nodes, width, height):
        """Record a fully parsed module together with its placement rectangle."""
        nodes_dic[name] = nodes
        tl[name] = Point(0, 0)
        br[name] = Point(width - 1, height - 1)

    module_name = ''
    module_nodes = []
    module_width = module_height = 0
    with open(sys.argv[1].replace('.tr', '.graph')) as f:
        for line in f:
            if line.startswith('module:'):  # head of a module
                if module_name != '':  # save previous module's info
                    _save_module(module_name, module_nodes, module_width, module_height)
                module_name = line.split(' ')[1]
                module_num_nodes = int(next(f, ''))  # the line after the header is the node count
                module_nodes = [Node(i, module_num_nodes * [0]) for i in range(module_num_nodes)]
                # width/height live in argv[5] and argv[6], which exist only when len(argv) >= 7
                if len(sys.argv) > 6:
                    module_width = int(sys.argv[5])
                    module_height = int(sys.argv[6])
                    assert module_width * module_height >= module_num_nodes
                else:
                    module_width = module_height = _square_size(module_num_nodes)
            else:  # body of a module; expected format of each line is "src dst"
                tokens = line.split(' ')
                assert len(tokens) == 2
                _add_edge(module_nodes, int(tokens[0]), int(tokens[1]))
    if module_name != '':  # save last module's info (nothing to save for an empty file)
        _save_module(module_name, module_nodes, module_width, module_height)
    return nodes_dic, tl, br
class Node(object):
    """A graph vertex: a qubit index together with its edge weights.

    ``weights`` is indexed by the other node's index; parser() fills it with
    the number of interactions between the two nodes.
    """

    def __init__(self, index, weights):
        self.index, self.weights = index, weights

    def __repr__(self):
        # Nodes print as their bare index.
        return '%s' % self.index
class Point(object):
    """An (x, y) grid cell; hashable and comparable so it can key a dict.

    The coordinate system is oriented as follows (x grows right, y grows down):
        0,0 1,0 2,0
        0,1 1,1 2,1
        0,2 1,2 2,2
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return '(%s, %s)' % (self.x, self.y)

    def __hash__(self):
        return hash((self.x, self.y))

    def __eq__(self, another):
        if not isinstance(another, Point):
            return NotImplemented  # let Python fall back instead of raising AttributeError
        return self.x == another.x and self.y == another.y

    def __ne__(self, another):
        # Python 2 does not derive != from __eq__; without this method two equal
        # Points would compare unequal there.
        result = self.__eq__(another)
        if result is NotImplemented:
            return result
        return not result
def get_locations(nodes, tl, br):
    """Return a dict mapping each node in ``nodes`` to a Point in the tl..br rectangle.

    tl and br are Point objects holding the (x, y) coordinates of the top-left
    and bottom-right cells (both inclusive). The rectangle is recursively cut
    in two along its longer side and partition() (gpmetis) decides which nodes
    go to each half, so nodes sharing many edges tend to end up close together.
    Fewer nodes than cells is only supported when tl == Point(0, 0) (asserted).
    """
    # Base cases:
    if len(nodes) == 1:  # for singleton, only choice is to place in the top-left cell
        return {nodes[0]: tl}
    if len(nodes) == 2:  # for two nodes, arbitrarily place the first node in top left
        return {nodes[0]: tl, nodes[1]: br}
    if count_num_edges(nodes) == 0:
        # Empty graph: METIS has nothing to optimize, so fill cells column by
        # column (x outer, y inner) until the nodes run out.
        cells = [Point(x, y)
                 for x in range(tl.x, br.x + 1)
                 for y in range(tl.y, br.y + 1)]
        return dict(zip(nodes, cells))
    # Recursive case, need to create and solve subproblems:
    filename = (splitext(basename(sys.argv[1]))[0] + '.p.' + sys.argv[2] +
                '.yx.' + sys.argv[3] + '.drop.' + sys.argv[4] + '.' +
                '_'.join(['delete', str(tl.x), str(tl.y), str(br.x), str(br.y)]))
    width = br.x - tl.x + 1
    height = br.y - tl.y + 1
    ret = {}
    # Special case for the very first call: e.g. 97 nodes on a 10x10 grid. Instead
    # of dividing the 97 nodes into 2 equal partitions, divide them into 90 nodes
    # placed on a 10x9 grid and 7 nodes placed on a 7x1 strip in the last row.
    if len(nodes) < width * height:
        assert tl == Point(0, 0)
        # divmod keeps coordinates integral on Python 3 (plain `/` gives floats).
        full_rows, leftover = divmod(len(nodes), width)
        if leftover == 0:  # the nodes exactly fill the first full_rows rows
            return get_locations(nodes, tl=Point(0, 0), br=Point(br.x, full_rows - 1))
        if full_rows == 0:
            # Fewer nodes than one row: place them on a partial first row instead
            # of asking METIS for an empty partition.
            return get_locations(nodes, tl=Point(0, 0), br=Point(leftover - 1, 0))
        nodes_tl, nodes_br = partition(nodes, width * full_rows, filename)
        # 97 into 10x10: 90 nodes go to (0, 0)-(9, 8) and 7 nodes to (0, 9)-(6, 9).
        ret.update(get_locations(nodes_tl, tl=Point(0, 0), br=Point(br.x, full_rows - 1)))
        ret.update(get_locations(nodes_br, tl=Point(0, full_rows),
                                 br=Point(leftover - 1, full_rows)))
        return ret
    # Bisect along the longer side; the first half gets floor(side / 2) columns/rows.
    if br.x - tl.x > br.y - tl.y:
        # Wider than tall: cut between columns `half` and `half + 1`.
        half = tl.x + (br.x - tl.x - 1) // 2
        nodes_left, nodes_right = partition(nodes, (half - tl.x + 1) * height, filename)
        ret.update(get_locations(nodes_left, tl=tl, br=Point(half, br.y)))
        ret.update(get_locations(nodes_right, tl=Point(half + 1, tl.y), br=br))
    else:
        # At least as tall as wide: cut between rows `half` and `half + 1`.
        half = tl.y + (br.y - tl.y - 1) // 2
        nodes_top, nodes_bottom = partition(nodes, width * (half - tl.y + 1), filename)
        ret.update(get_locations(nodes_top, tl=tl, br=Point(br.x, half)))
        ret.update(get_locations(nodes_bottom, tl=Point(tl.x, half + 1), br=br))
    return ret
def count_num_edges(nodes):
    """Return the number of distinct undirected edges between nodes in ``nodes``.

    An edge src-dst exists when ``src.weights[dst.index]`` is non-zero; only
    edges with both endpoints inside ``nodes`` are counted and weights are
    otherwise ignored. Each edge is seen from both endpoints, hence the halving.
    """
    directed = sum(1 for src_node in nodes for dst_node in nodes
                   if src_node.weights[dst_node.index])
    # Floor division keeps this an int on Python 3 too; the METIS header needs an integer.
    return directed // 2
def partition(nodes, size_nodes_tl, filename):
    """Split ``nodes`` into two groups with gpmetis, sized for two sub-rectangles.

    Args:
        nodes: list of Node objects to split.
        size_nodes_tl: exact number of nodes the first (top-left) group must get.
        filename: path of the METIS graph file to write; gpmetis writes its
            answer to ``filename + '.part.2'``.

    Returns:
        (nodes_tl, nodes_br) with len(nodes_tl) == size_nodes_tl and all
        remaining nodes in nodes_br.
    """
    import subprocess

    weight_0 = float(size_nodes_tl) / len(nodes)
    num_edges = count_num_edges(nodes)
    # METIS graph format: header "<#vertices> <#edges> 001" ('001' = weighted
    # edges), then one line per vertex of "<neighbor> <weight>" pairs, where the
    # neighbor is its 1-based position within `nodes`.
    with open(filename, 'w') as f:
        f.write('%s %s 001\n' % (len(nodes), num_edges))
        for src_node in nodes:
            for x, dst_node in enumerate(nodes, 1):
                weight = src_node.weights[dst_node.index]
                if weight:
                    f.write('%s %s ' % (x, weight))
            f.write('\n')
    benchname = (splitext(basename(sys.argv[1]))[0] + '.p.' + sys.argv[2] +
                 '.yx.' + sys.argv[3] + '.drop.' + sys.argv[4])
    tpwgts_file = benchname + '.delete_tpwgts'  # target partition weights file
    with open(tpwgts_file, 'w') as f:
        f.write('0 = %s' % str(weight_0))
    # Run gpmetis without a shell so paths with spaces or shell metacharacters
    # are passed intact; its stdout is discarded as before.
    with open(os.devnull, 'w') as devnull:
        subprocess.call(['gpmetis', '-ptype=rb', '-tpwgts=' + tpwgts_file,
                         '-ufactor=1', '-ncuts=500', filename, '2'],
                        stdout=devnull)
    with open(filename + '.part.2') as f:
        labels = [line.strip() for line in f]  # one partition label per vertex
    nodes_tl, nodes_br = [], []
    for node, label in zip(nodes, labels):
        if label == '0':
            nodes_tl.append(node)
        else:
            nodes_br.append(node)
    assert len(nodes_tl) + len(nodes_br) == len(nodes), 'nodes went missing somehow :('
    # METIS is not perfect and will not always yield exactly size_nodes_tl nodes
    # in nodes_tl; move nodes between the lists until the sizes are right.
    # TODO: pick the nodes to transfer between lists more intelligently
    while len(nodes_tl) > size_nodes_tl:
        nodes_br.append(nodes_tl.pop())
    while size_nodes_tl > len(nodes_tl):
        nodes_tl.append(nodes_br.pop())
    return nodes_tl, nodes_br
def replace(mapping_dic, f_tr):
    """Renumber the qubits of a .tr file and write the result to a .opt.tr file.

    Args:
        mapping_dic: maps a module name (the raw token after "module: " in the
            .tr file, trailing newline included) to a whitespace-separated list
            of node indices. The node listed at position i is renamed to i.
        f_tr: path of the input trace; output goes to the same path with '.tr'
            replaced by '.opt.tr'.

    Every "SRC: <n>" / "DST: <n>" field is rewritten using the mapping of the
    most recent module line. An index missing from the mapping is written as
    'SOMETHING WENT WRONG<n>'. Lines without "SRC" are copied unchanged.
    """
    out_path = f_tr.replace('.tr', '.opt.tr')
    src_re = re.compile(r'SRC: (\d+)')
    dst_re = re.compile(r'DST: (\d+)')
    d = {}

    def _new_index(match):
        """Look up the renumbered index for a matched old index (reads current `d`)."""
        old = match.group(1)
        return str(d.get(int(old), 'SOMETHING WENT WRONG' + old))

    # `with` guarantees the output is flushed and both files are closed.
    with open(f_tr) as fin, open(out_path, 'w') as fout:
        for line in fin:
            if line.startswith('module: '):
                module_name = line.split(' ')[1]
                d = {int(item): i for i, item in enumerate(mapping_dic[module_name].split())}
            if 'SRC' not in line:
                fout.write(line)
                continue
            line = src_re.sub(lambda match: 'SRC: ' + _new_index(match), line)
            line = dst_re.sub(lambda match: 'DST: ' + _new_index(match), line)
            fout.write(line)
# Run the placement pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()