Revision 095105451a22dd5a59c832656163a2efa462acb4 authored by Duncan Brown on 25 September 2018, 15:06:18 UTC, committed by Collin Capano on 26 October 2018, 14:35:14 UTC
1 parent cf96817
Raw File
pycbc_hdf5_splitbank
#!/usr/bin/env python

# Copyright (C) 2016  Soumi De
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
The code reads in a compressed template bank and splits it up into
smaller banks where the number of smaller banks is a user input
"""

import argparse
import numpy
import h5py
import logging
import pycbc, pycbc.version
from pycbc.waveform import bank
from numpy import random

__author__ = "Soumi De <soumi.de@ligo.org>"

# Command-line interface.
# Exactly one of --templates-per-bank, --number-of-banks and
# --output-filenames must be given (required mutually exclusive group); it
# decides how the input bank is divided. --output-prefix provides the file
# names when --output-filenames is not used.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
# The bank file is always needed, so let argparse report a missing value
# instead of failing later when the bank is opened.
parser.add_argument("--bank-file", type=str, required=True,
                    help="Bank hdf file to load.")
outbanks = parser.add_mutually_exclusive_group(required=True)
outbanks.add_argument("--templates-per-bank", type=int,
                      help="Number of templates in each output sub-banks. "
                      "Either specify this or --number-of-banks, not both.")
outbanks.add_argument("--number-of-banks", type=int,
                      help="Number of output sub-banks. Either specify this "
                      "or --templates-per-bank, not both.")
outbanks.add_argument("--output-filenames", nargs='*', default=None,
                      action="store",
                      help="Directly specify the names of the output files. "
                      "The bank will be split equally between files.")
# Sub-banks are numbered from 0 when the names are built from the prefix.
parser.add_argument("--output-prefix", default=None,
                    help="Prefix to add to the output template bank names, "
                    "for example 'H1L1-BANK'. Output file names would then be "
                    "'H1L1-BANK{x}.hdf' where {x} is 0,1,...")
# At most one ordering may be requested; with neither, the bank order is kept.
sortopt = parser.add_mutually_exclusive_group()
sortopt.add_argument("--mchirp-sort", action="store_true", default=False,
                     help="Sort templates by chirp mass before splitting")
sortopt.add_argument("--random-sort", action="store_true", default=False,
                     help="Sort templates randomly before splitting")
parser.add_argument("--random-seed", type=int,
                    help="Random seed for --random-sort")
parser.add_argument("--force", action="store_true", default=False,
                    help="Overwrite the given hdf file if it exists. "
                         "Otherwise, an error is raised.")
parser.add_argument("--verbose", action="store_true", default=False)


args = parser.parse_args()

# Input checks

# A seed only has an effect together with --random-sort, so reject it in
# every other case, including when no sorting option was given at all.
if args.random_seed is not None and not args.random_sort:
    raise RuntimeError("Can't supply a random seed if not sorting randomly!")

# --output-filenames is declared with nargs='*', so passing the flag with no
# values yields an empty list rather than None. That would satisfy the
# required option group while leaving nothing to decide the number of banks.
if args.output_filenames is not None and not args.output_filenames:
    raise RuntimeError("--output-filenames was given without any file names")

if args.output_filenames is None and args.output_prefix is None:
    raise RuntimeError("Must specify either output filenames or a prefix!")

if args.output_filenames and args.output_prefix:
    raise RuntimeError("Can't specify both output filenames and a prefix")

pycbc.init_logging(args.verbose)
logging.info("Loading bank")

# Load the input bank; `templates` is the table that gets reordered and split.
tmplt_bank = bank.TemplateBank(args.bank_file)

templates = tmplt_bank.table

# Optional reordering before the split. In both cases the reordered table is
# stored back on the bank object, presumably so that the later write_to_hdf
# calls on tmplt_bank slice the reordered table.
if args.random_sort:
    # Seed numpy's global generator (the same one used by the shuffle below)
    # so that a given seed reproduces the same ordering.
    if args.random_seed is not None:
        numpy.random.seed(args.random_seed)
    idx = numpy.arange(templates.size)
    numpy.random.shuffle(idx)
    templates = templates[idx]
    tmplt_bank.table = templates

if args.mchirp_sort:
    # Ascending order of the table's chirp mass values
    mcsort = numpy.argsort(templates.mchirp)
    templates = templates[mcsort]
    tmplt_bank.table = templates

# Split the templates in the bank taken as input into the smaller banks.
# This section only decides the layout: `num_files` sub-banks with
# `num_per_file` templates each (the last sub-bank also takes any remainder).

# If an array of filenames is given, one sub-bank is written per file name
if args.output_filenames:
    args.number_of_banks = len(args.output_filenames)

# If the number of output banks is taken as input calculate the number
# of templates to be stored per bank
if args.number_of_banks is not None:
    if args.number_of_banks < 1:
        raise RuntimeError("--number-of-banks must be a positive integer")
    num_files = args.number_of_banks
    num_per_file = templates.size // num_files

# If the number of templates per bank is taken as input calculate the
# number of output banks
elif args.templates_per_bank is not None:
    if args.templates_per_bank < 1:
        raise RuntimeError("--templates-per-bank must be a positive integer")
    num_per_file = args.templates_per_bank
    # Round down, but always produce at least one sub-bank so that a request
    # larger than the whole bank still writes the templates out.
    num_files = max(1, templates.size // num_per_file)

# Nothing usable was supplied (e.g. --output-filenames with no names)
else:
    raise RuntimeError("Must specify one of --templates-per-bank, "
                       "--number-of-banks or a non-empty --output-filenames")

# Write the sub-banks out one after another
logging.info("Generating the output sub-banks")
total_size = templates[:].size
final_bank = num_files - 1
for bank_idx in range(num_files):
    # Name of the hdf file receiving this sub-bank: either the user-supplied
    # name at this position, or the prefix followed by the sub-bank number.
    if args.output_filenames:
        dest = args.output_filenames[bank_idx]
    elif args.output_prefix:
        dest = args.output_prefix + str(bank_idx) + '.hdf'
    else:
        raise RuntimeError("I shouldn't be able to reach this point. One out "
                           "of --output-filenames and --output-prefix must "
                           "have been supplied!")

    # Every sub-bank covers num_per_file consecutive templates, except the
    # final one, which runs to the end of the table and so also picks up
    # whatever templates are left over.
    first = bank_idx * num_per_file
    if bank_idx == final_bank:
        last = total_size
    else:
        last = first + num_per_file

    # Write the slice [first, last) of the input bank to its own hdf file
    # and close the returned handle straight away.
    sub_bank = tmplt_bank.write_to_hdf(dest, first, last, force=args.force)
    sub_bank.close()

logging.info("finished")
back to top