Revision ba74cf52d14aa13acc33226f458d533c0fcbbacd authored by Tito Dal Canton on 15 April 2021, 14:59:29 UTC, committed by GitHub on 15 April 2021, 14:59:29 UTC
1 parent 2c8dfdb
Raw File
pycbc_splitbank
#!/usr/bin/python
#
# Copyright (C) 2014 LIGO Scientific Collaboration
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

"""
Splits a table in an xml file into multiple pieces
"""

import pycbc, pycbc.version, pycbc.pnutils
__author__  = "Alex Nitz <alex.nitz@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__    = pycbc.version.date
__program__ = "pycbc_splitbank"

import time
import argparse
from glue import gpstime
from glue.ligolw import ligolw
from glue.ligolw import table
from glue.ligolw import lsctables
from glue.ligolw import utils as ligolw_utils
from glue.ligolw.utils import process as ligolw_process
from pycbc import version
from numpy import random, ceil

class ContentHandler(ligolw.LIGOLWContentHandler):
    """LIGO_LW content handler used when reading the input template bank."""


# Apply glue's lsctables hook to the handler before it is handed to
# ligolw_utils.load_filename.
lsctables.use_in(ContentHandler)

 
# Command line parsing
# The module docstring starts with a newline; drop it for the description.
_desc = __doc__[1:]
parser = argparse.ArgumentParser(description=_desc)
parser.add_argument('--version', action='version', version=__version__)

# Exactly one of these three options decides how the bank is divided.
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--templates-per-bank', metavar='SAMPLES',
                    help='number of templates in the output banks', type=int)
group.add_argument('-n', '--number-of-banks', metavar='N',
                    help='Split template bank into N files', type=int)
group.add_argument("-O", "--output-filenames", nargs='*', default=None,
                    action="store",
                    metavar="OUTPUT_FILENAME", help="""Directly specify the
                    names of the output files. The number of files specified
                    here will dictate how to split the bank. It will be split
                    equally between all specified files.""")

parser.add_argument("-o", "--output-prefix", default=None, 
                    help="Prefix to add to the template bank name (name becomes output#.xml[.gz])" )
parser.add_argument("-z", "--write-compress", action="store_true",
                    help="Write compressed xml.gz files.")

parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False )
parser.add_argument("-t", "--bank-file", metavar='INPUT_FILE',
                    help='Template bank to split', required=True)
parser.add_argument("--sort-frequency-cutoff", 
                    help="Frequency cutoff to use for sorting the sub banks")
parser.add_argument("--sort-mchirp", action="store_true", default=False,
                    help='Sort templates by chirp mass before splitting')
parser.add_argument("--random-sort", action="store_true", default=False,
                    help='Sort templates randomly before splitting')
parser.add_argument("--random-seed", type=int,
                    help='Random seed to use when sorting randomly')

args = parser.parse_args()

# Explicit file names and a generated-name prefix are two competing ways of
# naming the outputs; accept only one of them.
if args.output_filenames and args.output_prefix:
    parser.error("Cannot supply --output-filenames with --output-prefix.")

# Fail fast when there is no way to name the outputs, instead of discovering
# it only after the bank has been loaded and the first document built.
if not args.output_filenames and not args.output_prefix:
    parser.error("No output file names available: supply --output-prefix "
                 "or at least one name with --output-filenames.")

if args.sort_mchirp and args.random_sort:
    parser.error("You can't sort by Mchirp *and* randomly, dumbass!")

# When the file names are listed explicitly, their count fixes the number of
# output banks.
if args.output_filenames:
    args.number_of_banks = len(args.output_filenames)

# Parse the input document with the lsctables-aware content handler.
indoc = ligolw_utils.load_filename(args.bank_file, verbose=args.verbose,
                                   contenthandler=ContentHandler)

# Look for a sngl_inspiral table first and fall back to sim_inspiral.
# get_table signals a missing (or duplicated) table with ValueError; only
# that case triggers the fallback, so unrelated failures (and Ctrl-C) are no
# longer silently swallowed as they were by the former bare ``except``.
try:
    template_bank_table = table.get_table(
        indoc, lsctables.SnglInspiralTable.tableName)
    tabletype = lsctables.SnglInspiralTable
except ValueError:
    template_bank_table = table.get_table(
        indoc, lsctables.SimInspiralTable.tableName)
    tabletype = lsctables.SimInspiralTable

# Total number of rows to distribute over the output files.
length = len(template_bank_table)

def mchirp_sort(x, y):
    """Old-style comparator ordering two templates by chirp mass.

    Parameters
    ----------
    x, y : table rows
        Objects exposing ``mass1`` and ``mass2`` attributes.

    Returns
    -------
    int
        -1, 0 or 1 when the chirp mass of ``x`` is respectively smaller
        than, equal to, or larger than that of ``y``.
    """
    mc1 = pycbc.pnutils.mass1_mass2_to_mchirp_eta(x.mass1, x.mass2)[0]
    mc2 = pycbc.pnutils.mass1_mass2_to_mchirp_eta(y.mass1, y.mass2)[0]
    # The cmp() builtin no longer exists in Python 3.  The int() casts keep
    # the subtraction valid if the comparisons yield numpy booleans.
    return int(mc1 > mc2) - int(mc1 < mc2)

def frequency_cutoff_sort(x, y):
    """Old-style comparator ordering two templates by cutoff frequency.

    The cutoff is evaluated by ``pycbc.pnutils.frequency_cutoff_from_name``
    using the name given on the command line (``--sort-frequency-cutoff``).

    Parameters
    ----------
    x, y : table rows
        Objects exposing ``mass1``, ``mass2``, ``spin1z`` and ``spin2z``.

    Returns
    -------
    int
        -1, 0 or 1 when the cutoff of ``x`` is respectively smaller than,
        equal to, or larger than that of ``y``.
    """
    p1 = pycbc.pnutils.frequency_cutoff_from_name(args.sort_frequency_cutoff,
                                                  x.mass1, x.mass2,
                                                  x.spin1z, x.spin2z)
    p2 = pycbc.pnutils.frequency_cutoff_from_name(args.sort_frequency_cutoff,
                                                  y.mass1, y.mass2,
                                                  y.spin1z, y.spin2z)
    # The cmp() builtin no longer exists in Python 3.  The int() casts keep
    # the subtraction valid if the comparisons yield numpy booleans.
    return int(p1 > p2) - int(p1 < p2)

# ---------------------------------------------------------------------------
# Choose the order in which templates are handed out.  Without any sort
# option the input table itself is consumed.
# ---------------------------------------------------------------------------
tt = template_bank_table

# Key-based sorting evaluates the sort quantity once per template (instead of
# twice per comparison) and, unlike sorted(cmp=...), also works on Python 3.
if args.sort_frequency_cutoff:
    def _cutoff_key(row):
        return pycbc.pnutils.frequency_cutoff_from_name(
            args.sort_frequency_cutoff, row.mass1, row.mass2,
            row.spin1z, row.spin2z)
    tt = sorted(template_bank_table, key=_cutoff_key)

if args.sort_mchirp:
    def _mchirp_key(row):
        return pycbc.pnutils.mass1_mass2_to_mchirp_eta(row.mass1,
                                                       row.mass2)[0]
    tt = sorted(template_bank_table, key=_mchirp_key)

if args.random_sort:
    if args.random_seed is not None:
        random.seed(args.random_seed)
    # Shuffles the input table in place; this is what gets split unless one
    # of the sort options above replaced ``tt`` with a sorted copy.
    random.shuffle(template_bank_table)

# ---------------------------------------------------------------------------
# Decide how many files to write and how many templates go in each.
# ---------------------------------------------------------------------------
if (args.number_of_banks or 0) > 0:
    # Fixed number of files: spread the templates as evenly as possible.
    num_files = args.number_of_banks
    num_per_file = length / float(num_files)
elif (args.templates_per_bank or 0) > 0:
    # Fixed file size: as many files as needed, the last one may be smaller.
    num_per_file = args.templates_per_bank
    num_files = int(ceil(float(length) / num_per_file))
else:
    # Reached with a zero/negative count or an empty --output-filenames
    # list; previously this surfaced as a NameError (or as no output).
    parser.error("--number-of-banks and --templates-per-bank must be "
                 "positive, and --output-filenames needs at least one name.")

# index_list[k] is the first template index of bank k; the total length is
# appended so that consecutive pairs delimit each bank.
index_list = [int(round(num_per_file*idx)) for idx in range(num_files)]
index_list.append(length)
assert index_list[0] == 0

# Refuse empty banks before anything is written to disk.
bank_bounds = list(zip(index_list[:-1], index_list[1:]))
if any(idx2 <= idx1 for idx1, idx2 in bank_bounds):
    raise ValueError("Cannot split %d templates into %d non-empty banks."
                     % (length, num_files))

# ---------------------------------------------------------------------------
# Build and write one document per bank.
# ---------------------------------------------------------------------------
for num, (idx1, idx2) in enumerate(bank_bounds):
    # create a blank xml document and add the process id
    outdoc = ligolw.Document()
    outdoc.appendChild(ligolw.LIGO_LW())

    ligolw_process.register_to_xmldoc(
        outdoc, __program__, args.__dict__, ifos=["G1"],
        version=version.version, cvs_repository=version.git_branch,
        cvs_entry_time=version.date)

    # Same table type and columns as the input bank.
    sngl_inspiral_table = lsctables.New(
        tabletype, columns=template_bank_table.columnnames)
    outdoc.childNodes[0].appendChild(sngl_inspiral_table)

    # Rows are taken from the END of ``tt``: bank 0 therefore receives the
    # last rows of the chosen ordering, in reverse order.
    for _ in range(idx2 - idx1):
        sngl_inspiral_table.append(tt.pop())

    # Stamp the process end time just before writing the xml doc to disk.
    proctable = table.get_table(outdoc, lsctables.ProcessTable.tableName)
    proctable[0].end_time = gpstime.GpsSecondsFromPyUTC(time.time())

    if args.output_filenames:
        outname = args.output_filenames[num]
    elif args.output_prefix:
        outname = args.output_prefix + str(num) + '.xml'
        if args.write_compress:
            outname += '.gz'
    else:
        raise ValueError("Cannot figure out how to set output file names.")

    # Compression follows the file extension, including for names given
    # explicitly with --output-filenames.
    ligolw_utils.write_filename(outdoc, outname, gz=outname.endswith('.gz'))
back to top