#!/usr/bin/env python
# Copyright (C) 2016 Ian W. Harry, Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""
Program for concatenating output files from pycbc_banksim_combine_banks with
a set of injection files. The *order* of the injection files *must* match the
bank files, and the number of injections in each must correspond one-to-one.
"""
from os.path import isfile
import imp
import argparse
import numpy
import h5py
from pycbc_glue.ligolw import utils, table
import pycbc
from pycbc import pnutils
from pycbc.waveform import TemplateBank
# Program metadata. The version string and date are read from pycbc.version.
__program__ = "pycbc_banksim_match_combine"
__author__ = "Ian Harry <ian.harry@astro.cf.ac.uk>"
__date__ = pycbc.version.date
__version__ = pycbc.version.git_verbose_msg
from pycbc_glue.ligolw.ligolw import LIGOLWContentHandler
from pycbc_glue.ligolw import lsctables
class mycontenthandler(LIGOLWContentHandler):
    """Content handler passed to utils.load_filename for the injection files.

    It adds nothing of its own to LIGOLWContentHandler; it exists so that
    lsctables.use_in can be applied to a class private to this script.
    """


# Apply lsctables.use_in to the handler before any document is loaded.
lsctables.use_in(mycontenthandler)
# Read command line options.
# The module docstring doubles as the --help description; the slice drops the
# newline that follows the opening triple quote.
_desc = __doc__[1:]
parser = argparse.ArgumentParser(description=_desc)
parser.add_argument("--version", action="version", version=__version__)
parser.add_argument("--verbose", action="store_true", default=False,
                    help="verbose output")
# Required: the result-collection loop iterates over this list, so omitting
# the option previously surfaced as an opaque "'NoneType' object is not
# iterable" TypeError instead of a usage message.
parser.add_argument("--match-files", nargs='+', required=True,
                    help="Explicit list of match files.")
# Disabled option, kept for reference. The injection file for each row is
# taken from the 'sim' column of the match files instead.
#parser.add_argument("--inj-files", nargs='+',
#                    help="Explicit list of injection files. These must be in "
#                         "the same order, and match one-to-one with the "
#                         "match-files.")
parser.add_argument("-o", "--output-file", required=True,
                    help="Output file name")
parser.add_argument("--filter-func-file", default=None,
                    help="This can be provided to give a function to define "
                         "which points are covered by the template bank "
                         "bounds, and which are not. The file should contain "
                         "a function called filter_injections, which should "
                         "take as call profile, mass1, mass2, spin1z, spin2z, "
                         "as numpy arrays.")
options = parser.parse_args()
# Column layout of the match files: one record per (template, injection) pair.
dtypem = {'names': ('match', 'bank', 'bank_i', 'sim', 'sim_i', 'sigmasq'),
          'formats': ('f8', 'S256', 'i4', 'S256', 'i4', 'f8')}

# Collect the results.
# numpy.loadtxt returns a 0-d array when a file holds a single row, which
# breaks len(res) and iteration over res if that file is the only input, so
# every chunk is promoted to 1-d first. The chunks are joined once at the end
# rather than re-copying the accumulated array for each additional file.
chunks = [numpy.atleast_1d(numpy.loadtxt(fil, dtype=dtypem))
          for fil in options.match_files]
res = numpy.concatenate(chunks)
# Tables already read from disk, keyed by the file name stored in the match
# files, so that each bank / injection file is loaded only once.
btables = {}
itables = {}

f = h5py.File(options.output_file, "w")

# If we move these over to HDF we can decide column names dynamically. This is
# why I'm using a dictionary now.
bank_par_list = ['mass1', 'mass2', 'spin1x', 'spin1y', 'spin1z', 'spin2x',
                 'spin2y', 'spin2z']
inj_par_list = ['mass1', 'mass2', 'spin1x', 'spin1y', 'spin1z', 'spin2x',
                'spin2y', 'spin2z', 'coa_phase', 'inclination', 'latitude',
                'longitude', 'polarization']
trig_par_list = ['match', 'sigmasq']

# One float64 column per parameter, with one entry per row of the results.
num_rows = len(res)
bank_params = {val: numpy.zeros(num_rows, dtype=numpy.float64)
               for val in bank_par_list}
inj_params = {val: numpy.zeros(num_rows, dtype=numpy.float64)
              for val in inj_par_list}
trig_params = {val: numpy.zeros(num_rows, dtype=numpy.float64)
               for val in trig_par_list}

for idx, row in enumerate(res):
    bank_file = row['bank']
    sim_file = row['sim']

    # Load the template bank and the injection table on first use only.
    if bank_file not in btables:
        btables[bank_file] = TemplateBank(bank_file).table
    if sim_file not in itables:
        indoc = utils.load_filename(sim_file, False,
                                    contenthandler=mycontenthandler)
        itables[sim_file] = table.get_table(indoc, "sim_inspiral")

    # The template and the injection this result row refers to.
    bt = btables[bank_file][row['bank_i']]
    it = itables[sim_file][row['sim_i']]

    for val in trig_par_list:
        trig_params[val][idx] = row[val]
    for val in bank_par_list:
        # If not present set to 0.
        # For example spin1x is not always stored in aligned-spin banks
        bank_params[val][idx] = getattr(bt, val, 0.)
    for val in inj_par_list:
        inj_params[val][idx] = getattr(it, val)

# Write every column as <group>/<parameter>, in the order of the lists above.
for group, names, columns in (('bank_params', bank_par_list, bank_params),
                              ('inj_params', inj_par_list, inj_params),
                              ('trig_params', trig_par_list, trig_params)):
    for val in names:
        f['{}/{}'.format(group, val)] = columns[val]
def _signal_recovery(match, sigmasq, weight=1.):
    """Return the signal recovery fraction and the effective fitting factor.

    Parameters
    ----------
    match : numpy.ndarray
        Values of the 'match' column.
    sigmasq : numpy.ndarray
        Values of the 'sigmasq' column, same length as ``match``.
    weight : float or numpy.ndarray, optional
        Extra per-point factor multiplying both numerator and denominator.
        The default of 1 gives the unweighted statistic.

    Returns
    -------
    frac : float
        sum((match * weight * sigmasq)**3) / sum((weight * sigmasq)**3).
    eff_ff : float
        Cube root of ``frac``.
    """
    numer = numpy.sum((match * weight * sigmasq)**3.)
    denom = numpy.sum((weight * sigmasq)**3.)
    frac = numer / denom
    return frac, frac**(1./3.)


if options.filter_func_file:
    # The user-supplied module must define filter_injections(mass1, mass2,
    # spin1z, spin2z); its return value is used below to index the arrays.
    modl = imp.load_source('filter_func', options.filter_func_file)
    func = modl.filter_injections
    bool_arr = func(inj_params['mass1'], inj_params['mass2'],
                    inj_params['spin1z'], inj_params['spin2z'])
    bool_arr = numpy.array(bool_arr)

all_match = trig_params['match']
all_sigmasq = trig_params['sigmasq']

# Also consider values over the whole set
# Signal recovery fraction
sig_rec, eff_ff = _signal_recovery(all_match, all_sigmasq)
f['sig_rec_fac'] = sig_rec
f['eff_fitting_factor'] = eff_ff

# Same statistic with every point additionally weighted by mchirp**(-5/6).
mchirp, _ = pnutils.mass1_mass2_to_mchirp_eta(inj_params['mass1'],
                                              inj_params['mass2'])
mc_weight = mchirp**(-5./6.)
sig_rec, eff_ff = _signal_recovery(all_match, all_sigmasq, mc_weight)
f['sig_rec_fac_chirp_mass_weighted'] = sig_rec
f['eff_fitting_factor_chirp_mass_weighted'] = eff_ff

if options.filter_func_file:
    num_filtered = len(inj_params['mass1'][bool_arr])
    if num_filtered == 0:
        # Nothing selected: the ratios are undefined, so store -1 as a
        # sentinel. The chirp-mass-weighted entries are written as well so the
        # output always contains the same set of datasets.
        f['frac_points_within_bank'] = 0
        f['filtered_sig_rec_fac'] = -1
        f['filtered_eff_fitting_factor'] = -1
        f['filtered_sig_rec_fac_chirp_mass_weighted'] = -1
        f['filtered_eff_fitting_factor_chirp_mass_weighted'] = -1
    else:
        f['frac_points_within_bank'] = \
            num_filtered / float(len(inj_params['mass1']))
        filt_match = all_match[bool_arr]
        filt_sigmasq = all_sigmasq[bool_arr]

        sig_rec, eff_ff = _signal_recovery(filt_match, filt_sigmasq)
        f['filtered_sig_rec_fac'] = sig_rec
        f['filtered_eff_fitting_factor'] = eff_ff

        sig_rec, eff_ff = _signal_recovery(filt_match, filt_sigmasq,
                                           mc_weight[bool_arr])
        f['filtered_sig_rec_fac_chirp_mass_weighted'] = sig_rec
        f['filtered_eff_fitting_factor_chirp_mass_weighted'] = eff_ff
    f['filtered_points'] = bool_arr

f.close()