Raw File
# Pranav Gokhale
# Usage: python arrange.py input_trace_file.tr [width height]
# (if width and height are not specified, default to sidelength
#  of smallest square big enough to contain all nodes)
# Input file: .tr file
# Output:
#    - placement of nodes on the rectangle

import os, sys, math, re
from os.path import basename
from os.path import splitext

module_num_nodes = 0


def main():
  trace_to_graph(sys.argv[1])
  mapping_dic = {}
  nodes_dic, tl_dic, br_dic = parser()  # tl = top left coordinate, br = bottom right coordinate
  benchname = splitext(basename(sys.argv[1]))[0] + '.p.' + sys.argv[2] + '.yx.' + sys.argv[3] + '.drop.' + sys.argv[4]  
  for module_name, nodes in nodes_dic.iteritems():
    locations = get_locations(nodes, tl=tl_dic[module_name], br=br_dic[module_name])
    os.system('rm -f '+benchname+'.delete_*')  # partition function creates many files prefixed with delete_
    node_map = {v: k for k, v in locations.items()}
    mapping = ''
    for y in range(br_dic[module_name].y + 1):
      for x in range(br_dic[module_name].x + 1):
        if node_map.get(Point(x, y)):
          mapping += str(node_map.get(Point(x, y)).index) + ' '
        else:
          mapping += ' '
      mapping += '\n'
    mapping_dic[module_name] = mapping
  replace(mapping_dic, sys.argv[1])
  os.system('rm -f '+benchname+'.delete_*')


def trace_to_graph(infile):
  """Converts .tr file to .graph file."""
  assert '.tr' in infile
  outfile = infile.replace('.tr', '.graph')

  infile = open(infile)
  outfile = open(outfile, 'w')

  for line in infile:
    if line.startswith('module:'):
      tokens = line.split(' ')
      module_name = tokens[1]
      outfile.write('module: %s' % module_name)
      continue
    elif line.startswith('num_nodes:'):
      tokens = line.split(' ')
      num_nodes = int(tokens[1])
      outfile.write('%s\n' % num_nodes)
      continue
    elif not line.startswith('ID:'):  # skip the header lines
      continue
    elif 'DST:' not in line:          # skip single qubit gates
      continue

    tokens = line.split(' ')
    assert tokens[4].startswith('SRC:')
    assert tokens[6].startswith('DST:')
    src = int(tokens[5])
    dst = int(tokens[7])
    assert src < num_nodes
    assert dst < num_nodes
    outfile.write('%s %s\n' % (src, dst))

def parser():
  """Parses nodes from .graph input file. Returns nodes, br, tl.

    Example .graph file:
    5
    0 1
    1 2
    2 3
    3 4

    creates
    nodes = [
    Node(0, [0, 1, 0, 0, 0])
    Node(1, [1, 0, 1, 0, 0])
    Node(2, [0, 1, 0, 1, 0])
    Node(3, [0, 0, 1, 0, 1])
    Node(4, [0, 0, 0, 1, 0])
    """
  global module_num_nodes
  nodes_dic = {}
  tl = {}
  br = {}
  
  def _square_size(module_num_nodes):
    """Length of the smallest integer sized square with area of at least module_num_nodes.

      (size - 1) * (size - 1) < module_num_nodes <= size * size"""
    return int(math.ceil(math.sqrt(module_num_nodes)))

  def _add_edge(src, dst):
    """Add an undirected edge."""
    module_nodes[src].weights[dst] += 1
    module_nodes[dst].weights[src] += 1

  f = open(sys.argv[1].replace('.tr', '.graph'))
  line = f.readline()

  module_name = ''
  while line != '':
    # head of a module
    if line.startswith('module:'):
      if module_name != '':        # save previous module's info
        nodes_dic[module_name] = module_nodes
        tl[module_name] = Point(0, 0)
        br[module_name] = Point(module_width - 1, module_height - 1)
    
      tokens = line.split(' ')
      module_name = tokens[1]
      module_num_nodes = int(f.readline())
      module_nodes = []

      if len(sys.argv) == 4:        # if width and height are specified from command line
        module_width = int(sys.argv[5])
        module_height = int(sys.argv[6])
        assert module_width * module_height >= module_num_nodes
      else:
        square_size = _square_size(module_num_nodes)
        module_width, module_height = square_size, square_size

      for i in range(module_num_nodes):
        n = Node(i, module_num_nodes * [0])
        module_nodes.append(n)

    # body of a module
    else:
      tokens = line.split(' ')      # expected format of each line is "src dst"
      assert len(tokens) == 2
      src, dst = int(tokens[0]), int(tokens[1])
      _add_edge(src, dst)

    line = f.readline()  

  nodes_dic[module_name] = module_nodes # save last module's info
  tl[module_name] = Point(0, 0)
  br[module_name] = Point(module_width - 1, module_height - 1)

  return nodes_dic, tl, br

class Node(object):
  def __init__(self, index, weights):
    self.index = index
    self.weights = weights
  
  def __repr__(self):
    return '%s' % (self.index)


class Point(object):
  # direction of coordinate system is as follow:
    # 0,0  1,0  2,0
    # 0,1  1,1  2,1
    # 0,2  1,2  2,2
  def __init__(self, x, y):
    self.x = x
    self.y = y

  def __repr__(self):
    return '(%s, %s)' % (self.x, self.y)

  def __hash__(self):
    return hash((self.x, self.y))

  def __eq__(self, another):
    return self.x == another.x and self.y == another.y


def get_locations(nodes, tl, br):
  """Returns a mapping from each node in nodes to a location in the tl<->br square.

  tl and br are tuples representing (x, y) coordinates of the top left and bottom right.
  """
  
  # Base cases:
  if len(nodes) == 1:  # for singleton, only choice is to place in the single spot in 1x1 square
    return {nodes[0]: tl}
  if len(nodes) == 2:  # for two nodes, arbitrarily chose to place the first node in top left
    return {nodes[0]: tl, nodes[1]: br}

  # Recursive case, need to create and solve subproblems:
  ret = {}

  num_edges = count_num_edges(nodes)
  if num_edges == 0:  # for empty graphs, no need to run METIS, just assign arbitrarily
    i = 0
    for x in range(tl.x, br.x+1): 
      for y in range(tl.y, br.y+1):
        if i < len(nodes):
          ret.update({nodes[i]: Point(x,y)})
          i += 1
    return ret

  filename = splitext(basename(sys.argv[1]))[0] + '.p.' + sys.argv[2] + '.yx.' + sys.argv[3] + '.drop.' + sys.argv[4] + '.' +\
      '_'.join(['delete', str(tl.x), str(tl.y), str(br.x), str(br.y)])  

  # special case for the very first call of get_locations. For example, suppose that there are
  # 97 nodes on a 10x10 grid. Instead of dividing the 97 nodes into 2 equal partitions, we should
  # divide them into a partition of 90 nodes and a partition of 7 nodes. The former should be
  # placed on a 10x9 grid and te latter should be placed on a 1x7 grid.
  if len(nodes) < (br.x - tl.x + 1) * (br.y - tl.y + 1):
    assert tl == Point(0, 0)
    size_tl_nodes = (br.x + 1) * int(len(nodes) / (br.x + 1))
    if size_tl_nodes == len(nodes):
      ret.update(get_locations(nodes, tl=Point(0, 0), br=Point(br.x, len(nodes) / (br.x + 1) - 1)))
      return ret

    nodes_tl, nodes_br = partition(nodes, size_tl_nodes, filename)
    # complicated indexing here. As an example, for the 97 into 10x10 case, we want to send 90 nodes
    # to a rectangle spanned by tl=Point(0, 0) and br=Point(9, 8) and we want to send 7 nodes to a 
    # rectangle spanned by tl=Point(0, 9) and br=Point(6, 9)
    ret.update(get_locations(nodes_tl, tl=Point(0, 0), br=Point(br.x, len(nodes) / (br.x + 1) - 1)))
    ret.update(get_locations(nodes_br, tl=Point(0, len(nodes) / (br.x + 1)), br=Point(len(nodes) % (br.x + 1) - 1, len(nodes) / (br.x + 1))))
    return ret

  if br.x - tl.x > br.y - tl.y:  # if rectangle is wider than tall, split on y axis
    half = tl.x + (br.x - tl.x - 1) / 2
    size_tl_nodes = (half - tl.x + 1) * (br.y - tl.y + 1)
  else:  # split on x axis
    half = tl.y + (br.y - tl.y - 1) / 2
    size_tl_nodes = (br.x - tl.x + 1) * (half - tl.y + 1)

  nodes_tl, nodes_br = partition(nodes, size_tl_nodes, filename)

  if br.x - tl.x > br.y - tl.y:  # if rectangle is wider than tall, split on y axis
    ret.update(get_locations(nodes_tl, tl=tl, br=Point(half, br.y)))
    ret.update(get_locations(nodes_br, tl=Point(half + 1,tl.y), br=br))
  else:  # split on x axis
    ret.update(get_locations(nodes_tl, tl=tl, br=Point(br.x, half)))
    ret.update(get_locations(nodes_br, tl=Point(tl.x, half + 1), br=br))

  return ret


def count_num_edges(nodes):
  num_edges = 0
  for src_node in nodes:
    for dst_node in nodes:
      if src_node.weights[dst_node.index]:
        num_edges += 1
  num_edges /= 2  # don't double count edges
  
  return num_edges


def partition(nodes, size_nodes_tl, filename):
  """Returns nodes_tl, nodes_br."""
  weight_0 = float(size_nodes_tl) / len(nodes)
  num_edges = count_num_edges(nodes)

  f = open(filename, 'w')
  f.write('%s %s 001\n' % (len(nodes), num_edges))
  for src_node in nodes:
    x = 1
    for dst_node in nodes:
      if src_node.weights[dst_node.index]:
        f.write('%s %s ' % (x, src_node.weights[dst_node.index]))
      x += 1
    f.write('\n')
  f.close()

  benchname = splitext(basename(sys.argv[1]))[0] + '.p.' + sys.argv[2] + '.yx.' + sys.argv[3] + '.drop.' + sys.argv[4]  
  f = open(benchname+'.delete_tpwgts', 'w')  # target partition weights file
  f.write('0 = %s' % str(weight_0))
  f.close()

  os.system('gpmetis -ptype=rb -tpwgts='+benchname+'.delete_tpwgts -ufactor=1 -ncuts=500 ' + filename + ' 2 > /dev/null')

  f = open(filename + '.part.2')
  nodes_tl, nodes_br = [], []
  index = 0
  for line in f:
    if line == '0\n':
      nodes_tl.append(nodes[index])
    else:
      nodes_br.append(nodes[index])
    index += 1

  assert len(nodes_tl) + len(nodes_br) == len(nodes), 'nodes went missing somehow :('

  # METIS is not perfect and will not always actually yield size_nodes_tl nodes in nodes_tl
  # in this case, just pop enough nodes from one list to the other, until sizes are appropriate
  # TODO: pick the nodes to transfer between lists more intelligently
  while len(nodes_tl) > size_nodes_tl:
    nodes_br.append(nodes_tl.pop())
  while size_nodes_tl > len(nodes_tl):
    nodes_tl.append(nodes_br.pop())

  return nodes_tl, nodes_br


def replace(mapping_dic, f_tr):
  """Replace the nodes in the .tr file using the mapping and generate .opt.tr file."""
  f_opt_tr = f_tr.replace('.tr', '.opt.tr')
  f_tr = open(f_tr)
  f_opt_tr = open(f_opt_tr, 'w')
  for line in f_tr:
    if line.startswith('module: '):
      tokens = line.split(' ')
      module_name = tokens[1]
      mapping = mapping_dic[module_name].split()
      d = {}
      i = 0
      for item in mapping:
        d[int(item)] = i
        i += 1
    if 'SRC' not in line:
      f_opt_tr.write(line)
      continue
    line =  re.sub(r'SRC: (\d+)', lambda match: 'SRC: '+str(d.get(int(match.group(1)), 'SOMETHING WENT WRONG' + match.group(1))), line)
    line =  re.sub(r'DST: (\d+)', lambda match: 'DST: '+str(d.get(int(match.group(1)), 'SOMETHING WENT WRONG' + match.group(1))), line)
    f_opt_tr.write(line)


if __name__ == "__main__":
  main()
back to top