Revision d44b370e336fd20143c901199f90ab5e2aa19a46 authored by Alex Nitz on 30 April 2017, 15:47:25 UTC, committed by GitHub on 30 April 2017, 15:47:25 UTC
1 parent 29416a2
Raw File
pycbc_banksim_match_combine
#!/usr/bin/env python

# Copyright (C) 2016 Ian W. Harry, Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Program for concatenating output files from pycbc_banksim_combine_banks with
a set of injection files. The *order* of the injection files *must* match the
bank files, and the number of injections in each must correspond one-to-one.
"""

from os.path import isfile
import imp
import argparse
import numpy
import h5py
from pycbc_glue.ligolw import utils, table
import pycbc
from pycbc import pnutils
from pycbc.waveform import TemplateBank

# Script metadata. The version string and date are read from the pycbc
# package; __version__ is what the --version option prints.
__author__  = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__version__ = pycbc.version.git_verbose_msg
__date__    = pycbc.version.date
__program__ = "pycbc_banksim_match_combine"

from pycbc_glue.ligolw.ligolw import LIGOLWContentHandler
from pycbc_glue.ligolw import lsctables
class mycontenthandler(LIGOLWContentHandler):
    """Content handler handed to utils.load_filename when the injection
    (sim_inspiral) XML files are read. It adds nothing to its base class."""


# Pass the handler through lsctables.use_in before any document is loaded.
lsctables.use_in(mycontenthandler)

# Read command line options.
# The module docstring begins with a newline; drop it for the help text.
_desc = __doc__[1:]
parser = argparse.ArgumentParser(description=_desc)

parser.add_argument("--version", action="version", version=__version__)
parser.add_argument("--verbose", action="store_true", default=False,
                    help="verbose output")
# Required: without it options.match_files is None and the loop that reads
# the match files fails with a bare TypeError instead of a usage message.
parser.add_argument("--match-files", nargs='+', required=True,
                    help="Explicit list of match files.")
# NOTE: there is no --inj-files option. The injection file belonging to each
# match is named in the 'sim' column of the match files themselves.
parser.add_argument("-o", "--output-file", required=True,
                    help="Output file name")
parser.add_argument("--filter-func-file", default=None,
                    help="This can be provided to give a function to define "
                         "which points are covered by the template bank "
                         "bounds, and which are not. The file should contain "
                         "a function called filter_injections, which should "
                         "take as call profile, mass1, mass2, spin1z, spin2z, "
                         "as numpy arrays.")
options = parser.parse_args()

# Column layout of one row of a match file: the match value, the bank file
# name and row index, the injection file name and row index, and sigmasq.
dtypem = {'names': ('match', 'bank', 'bank_i', 'sim', 'sim_i', 'sigmasq'),
          'formats': ('f8', 'S256', 'i4', 'S256', 'i4', 'f8')}

# Collect the results from every match file into one structured array.
if not options.match_files:
    parser.error("at least one match file must be given with --match-files")

# numpy.loadtxt collapses a one-line file to a 0-d array, so every chunk is
# promoted to 1-d; otherwise a single one-row input makes len(res) fail later.
# Concatenating once also avoids re-copying the accumulated array per file.
res = numpy.concatenate([numpy.atleast_1d(numpy.loadtxt(fil, dtype=dtypem))
                         for fil in options.match_files])

# Caches of the bank and injection tables loaded so far, keyed by the file
# name found in the match rows.
btables = {}
itables = {}

f = h5py.File(options.output_file, "w")

# The per-injection values are held in dictionaries keyed by column name, so
# that the column names could be decided dynamically if these inputs are ever
# moved over to HDF.
n_rows = len(res)

bank_par_list = ['mass1', 'mass2', 'spin1x', 'spin1y', 'spin1z', 'spin2x',
                 'spin2y', 'spin2z']
bank_params = {name: numpy.zeros(n_rows, dtype=numpy.float64)
               for name in bank_par_list}

inj_par_list = ['mass1', 'mass2', 'spin1x', 'spin1y', 'spin1z', 'spin2x',
                'spin2y', 'spin2z', 'coa_phase', 'inclination', 'latitude',
                'longitude', 'polarization']
inj_params = {name: numpy.zeros(n_rows, dtype=numpy.float64)
              for name in inj_par_list}

trig_par_list = ['match', 'sigmasq']
trig_params = {name: numpy.zeros(n_rows, dtype=numpy.float64)
               for name in trig_par_list}

# Fill the output arrays one match row at a time. Each bank and injection
# file is loaded only the first time a row refers to it.
for idx, row in enumerate(res):
    bank_name = row['bank']
    sim_name = row['sim']

    if bank_name not in btables:
        temp_bank = TemplateBank(bank_name)
        btables[bank_name] = temp_bank.table

    if sim_name not in itables:
        indoc = utils.load_filename(sim_name, False,
                                    contenthandler=mycontenthandler)
        itables[sim_name] = table.get_table(indoc, "sim_inspiral")

    # The template and the injection this match row points at.
    bt = btables[bank_name][row['bank_i']]
    it = itables[sim_name][row['sim_i']]

    for val in trig_par_list:
        trig_params[val][idx] = row[val]
    for val in bank_par_list:
        # If not present set to 0.
        # For example spin1x is not always stored in aligned-spin banks
        bank_params[val][idx] = getattr(bt, val, 0.)
    for val in inj_par_list:
        # Injection columns are all expected to exist; a missing one raises.
        inj_params[val][idx] = getattr(it, val)

# Store every per-injection array in the output file, one dataset per
# parameter: bank values first, then injection values, then match values.
for path_template, names, arrays in (
        ('bank_params/{}', bank_par_list, bank_params),
        ('inj_params/{}', inj_par_list, inj_params),
        ('trig_params/{}', trig_par_list, trig_params)):
    for name in names:
        f[path_template.format(name)] = arrays[name]

# Optionally evaluate the user-supplied filter_injections function on the
# injection masses and aligned spins; its result is kept as a numpy array.
if options.filter_func_file:
    # imp.load_source executes the given file as a module named filter_func.
    filter_module = imp.load_source('filter_func', options.filter_func_file)
    bool_arr = numpy.array(
        filter_module.filter_injections(inj_params['mass1'],
                                        inj_params['mass2'],
                                        inj_params['spin1z'],
                                        inj_params['spin2z']))

# Summary statistics over the whole injection set.
all_match = trig_params['match']
all_sigmasq = trig_params['sigmasq']

# Signal recovery fraction: sum of (match * sigmasq)**3 divided by the sum of
# sigmasq**3. The effective fitting factor is its cube root.
sig_rec_fac = (numpy.sum((all_match * all_sigmasq)**3.) /
               numpy.sum(all_sigmasq**3.))
f['sig_rec_fac'] = sig_rec_fac
f['eff_fitting_factor'] = sig_rec_fac**(1./3.)

# The same two quantities, with every injection additionally weighted by
# mchirp**(-5/6) before cubing.
mchirp, _ = pnutils.mass1_mass2_to_mchirp_eta(inj_params['mass1'],
                                              inj_params['mass2'])
mc_weight = mchirp**(-5./6.)
sig_rec_fac_mcweighted = (
    numpy.sum((all_match * mc_weight * all_sigmasq)**3.) /
    numpy.sum((mc_weight * all_sigmasq)**3.))
f['sig_rec_fac_chirp_mass_weighted'] = sig_rec_fac_mcweighted
f['eff_fitting_factor_chirp_mass_weighted'] = \
    sig_rec_fac_mcweighted**(1./3.)

# Repeat the summary statistics using only the injections selected by the
# filter function, then record the selection itself.
if options.filter_func_file:
    num_filtered = len(inj_params['mass1'][bool_arr])
    if num_filtered == 0:
        # Nothing was selected, so the ratios are undefined. Write the -1
        # placeholder for every filtered statistic, so that the output file
        # has the same datasets whichever branch is taken.
        f['frac_points_within_bank'] = 0
        f['filtered_sig_rec_fac'] = -1
        f['filtered_eff_fitting_factor'] = -1
        f['filtered_sig_rec_fac_chirp_mass_weighted'] = -1
        f['filtered_eff_fitting_factor_chirp_mass_weighted'] = -1
    else:
        f['frac_points_within_bank'] = \
            num_filtered / float(len(inj_params['mass1']))
        filt_match = trig_params['match'][bool_arr]
        filt_sigmasq = trig_params['sigmasq'][bool_arr]
        srfn_filt = numpy.sum((filt_match * filt_sigmasq)**3.)
        srfd_filt = numpy.sum(filt_sigmasq**3)
        f['filtered_sig_rec_fac'] = srfn_filt / srfd_filt
        f['filtered_eff_fitting_factor'] = (srfn_filt / srfd_filt)**(1./3.)

        # Use a separate name so the full-set mchirp array is not clobbered.
        filt_mchirp = mchirp[bool_arr]
        srfn_mcweighted = numpy.sum((filt_match * filt_mchirp**(-5./6.) *
                                     filt_sigmasq)**3.)
        srfd_mcweighted = numpy.sum((filt_mchirp**(-5./6.) *
                                     filt_sigmasq)**3.)
        f['filtered_sig_rec_fac_chirp_mass_weighted'] = \
            srfn_mcweighted / srfd_mcweighted
        f['filtered_eff_fitting_factor_chirp_mass_weighted'] = \
            (srfn_mcweighted / srfd_mcweighted)**(1./3.)

    f['filtered_points'] = bool_arr

f.close()
back to top