Revision 29df1b564abf7a6ed5230a5e97a85da70c403c2a authored by Ian Harry on 15 June 2017, 14:00:24 UTC, committed by Duncan Brown on 15 June 2017, 14:00:24 UTC
1 parent 23fd945
Raw File
pycbc_faithsim
#! /usr/bin/env python
## Copyright (C) 2012  Alex Nitz
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
import sys
import logging
from numpy import loadtxt,complex64,float32
import argparse
from pycbc_glue.ligolw import utils as ligolw_utils
from pycbc_glue.ligolw import table, lsctables, ligolw

import pycbc.version
import pycbc.strain
import pycbc.psd
from pycbc.pnutils import mass1_mass2_to_mchirp_eta
from pycbc.waveform import td_approximants, fd_approximants
from pycbc.waveform import get_two_pol_waveform_filter
from pycbc import DYN_RANGE_FAC
from pycbc.types import FrequencySeries, zeros
from pycbc.filter import match, overlap, sigma
from pycbc.scheme import CPUScheme, CUDAScheme

class ContentHandler(ligolw.LIGOLWContentHandler):
    """Script-local LIGO_LW content handler used to load the parameter file.

    Adds nothing of its own; it exists so that ``lsctables.use_in`` is applied
    to a subclass private to this script rather than to the library base
    class.
    """


# Apply lsctables' handler hook to the class; it is later passed as the
# ``contenthandler`` argument when the parameter file is loaded.
lsctables.use_in(ContentHandler)

def update_progress(progress):
    """Redraw a 50-character text progress bar on standard output.

    Parameters
    ----------
    progress : int or float
        Percentage complete. It is truncated to an integer before use, so a
        float produced by true division is accepted as well.

    The bar is written with ``sys.stdout.write`` instead of the Python 2
    ``print`` statement so the function is valid on both Python 2 and 3.
    The line starts with carriage returns so successive calls overwrite each
    other; at exactly 100 percent `` Done`` and a newline are appended.
    """
    percent = int(progress)
    # One '#' per two percent; floor division keeps the count an integer
    # regardless of the interpreter's division semantics.
    filled = percent // 2
    bar = '#' * filled + ' ' * (50 - filled)
    sys.stdout.write('\r\r[{0}] {1}%'.format(bar, percent))
    if percent == 100:
        # Matches what ``print '...',`` followed by ``print "Done"`` emitted.
        sys.stdout.write(' Done\n')
    sys.stdout.flush()


def get_waveform(approximant, phase_order, amplitude_order, spin_order,
                 tapering, template_params, start_frequency, sample_rate,
                 length):
    """Generate the plus-polarization frequency-domain waveform for one row.

    Parameters
    ----------
    approximant : str
        Approximant name forwarded to the waveform generator.
    phase_order, amplitude_order, spin_order : int
        PN orders forwarded to the waveform generator.
    tapering : str or None
        Taper setting. When not None it takes precedence over any ``taper``
        attribute found on ``template_params``.
    template_params : object
        Table row carrying the waveform parameters.
    start_frequency : float
        Passed to the generator as ``f_lower``.
    sample_rate : float
        Sample rate; sets ``delta_t = 1 / sample_rate``.
    length : int or float
        Number of time-domain samples; sets ``delta_f = sample_rate / length``
        and the number of frequency bins ``int(length / 2) + 1``.

    Returns
    -------
    The plus polarization returned by ``get_two_pol_waveform_filter``,
    multiplied by ``DYN_RANGE_FAC``.
    """
    # float() guards against integer division if both arguments are ints.
    delta_f = float(sample_rate) / length
    # Size the output buffers from ``length`` itself rather than from a
    # module-level global, so the function does not depend on script state.
    n_freq = int(length / 2) + 1
    vecplus = FrequencySeries(zeros(n_freq), delta_f=delta_f,
                              dtype=complex64)
    veccross = FrequencySeries(zeros(n_freq), delta_f=delta_f,
                               dtype=complex64)

    # An explicit taper argument overrides the row's own ``taper`` attribute;
    # with neither present no taper is requested.
    if tapering is not None:
        curr_taper = tapering
    else:
        curr_taper = getattr(template_params, 'taper', None)

    # NOTE: for now only hplus is used! For precessing faithsims one would want
    #       to also specify a polarization phase, or "u_val".
    hplus, _hcross = get_two_pol_waveform_filter(
        vecplus, veccross, template_params, approximant=approximant,
        spin_order=spin_order, phase_order=phase_order,
        delta_t=1.0 / sample_rate, delta_f=delta_f,
        f_lower=start_frequency, amplitude_order=amplitude_order,
        taper=curr_taper)

    return hplus * DYN_RANGE_FAC

###############################################################################

# Union of the time-domain and frequency-domain approximant names, without
# duplicates; used as the allowed values for the approximant options.
aprs = list(set(td_approximants() + fd_approximants()))
psd_names = pycbc.psd.get_lalsim_psd_list()

# File I/O settings
taper_choices = ["start", "end", "startend"]
parser = argparse.ArgumentParser(
    usage='',
    description="Calculate faithfulness for a set of waveforms.")
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("-V", "--verbose", action="store_true", default=False,
                    help="print extra debugging information")
parser.add_argument("--param-file", dest="bank_file", metavar="FILE",
                    help="Sngl or Sim Inspiral Table containing waveform "
                         "parameters.")
parser.add_argument("--match-file", dest="out_file", metavar="FILE",
                    help="File to output match results to.")

# Waveform generation settings: the same six options are defined for
# waveform 1 and then for waveform 2.
_taper_help = ("For time-domain approximants, taper the start and/or"
               " end of the waveform before FFTing. This can also be"
               " provided in the sim_inspiral table. Providing this"
               " option will override any entry in that table.")
for _wf in (1, 2):
    parser.add_argument("--waveform%d-approximant" % _wf, choices=aprs,
                        help="Waveform %d's approximant." % _wf)
    parser.add_argument("--waveform%d-phase-order" % _wf, type=int,
                        default=-1,
                        help="PN order to use for the phase")
    parser.add_argument("--waveform%d-amplitude-order" % _wf, type=int,
                        default=-1,
                        help="PN order to use for the amplitude.")
    parser.add_argument("--waveform%d-spin-order" % _wf, type=int,
                        default=-1,
                        help="Spin terms up to the given pN order are "
                             "included")
    parser.add_argument("--waveform%d-start-frequency" % _wf, type=float,
                        help="Starting frequency for waveform generation.")
    parser.add_argument("--waveform%d-taper-template" % _wf,
                        choices=taper_choices, default=None,
                        help=_taper_help)

# Filter settings
parser.add_argument('--filter-low-frequency-cutoff', type=float,
                    metavar='FREQ',
                    help='Low frequency cutoff of matched filter.')
parser.add_argument('--filter-high-frequency-cutoff', type=float,
                    metavar='FREQ',
                    help='High frequency cutoff of matched filter.')
parser.add_argument("--filter-sample-rate", type=float,
                    help="Filter Sample Rate [Hz]")
parser.add_argument("--filter-waveform-length", type=int,
                    help="Length of waveform for filtering, should be longer "
                         "than all waveforms and include some padding.")

parser.add_argument("--cuda", action="store_true",
                    help="Use CUDA for calculations.")

# PSD options come from pycbc.psd, data-reading options from pycbc.strain.
pycbc.psd.insert_psd_option_group(parser)
pycbc.strain.insert_strain_option_group(parser)

options = parser.parse_args()

pycbc.init_logging(options.verbose)

# Processing scheme: CUDA when requested on the command line, else CPU.
ctx = CUDAScheme() if options.cuda else CPUScheme()

# Load the parameter file. The sngl_inspiral table is tried first; if
# get_table raises ValueError the sim_inspiral table is used instead.
indoc = ligolw_utils.load_filename(options.bank_file, False,
                                   contenthandler=ContentHandler)
try:
    waveform_table = table.get_table(
        indoc, lsctables.SnglInspiralTable.tableName)
except ValueError:
    waveform_table = table.get_table(
        indoc, lsctables.SimInspiralTable.tableName)

# Open the output file that will receive one line of results per waveform.
fout = open(options.out_file, "w")
logging.info("Writing matches to %s", options.out_file)

# filter_N: number of time-domain samples (waveform length times sample rate)
# filter_n: corresponding number of frequency-domain samples
# delta_f : frequency resolution
filter_N = options.filter_sample_rate * options.filter_waveform_length
filter_n = int(filter_N / 2) + 1
delta_f = float(options.filter_sample_rate) / filter_N

logging.info("Number of Waveforms      : %d", len(waveform_table))

logging.info("Reading and Interpolating PSD")
# h(t) is only needed when the PSD is to be estimated from data.
strain = None
if options.psd_estimation:
    logging.info("Obtaining h(t) for PSD generation")
    strain = pycbc.strain.from_cli(options, pycbc.DYN_RANGE_FAC)

psd = pycbc.psd.from_cli(
    options, length=filter_n, delta_f=delta_f,
    low_frequency_cutoff=options.filter_low_frequency_cutoff,
    strain=strain, dyn_range_factor=DYN_RANGE_FAC, precision='single')

# Parallel result lists, one entry per row of waveform_table.
matches, overlaps, time_offsets, s1s, s2s = [], [], [], [], []
logging.info("Calculating Overlaps")
# Keyword arguments shared by every match/overlap/sigma call below.
filter_kwargs = dict(
    psd=psd,
    low_frequency_cutoff=options.filter_low_frequency_cutoff,
    high_frequency_cutoff=options.filter_high_frequency_cutoff)
num_waveforms = len(waveform_table)
with ctx:
    # For every row generate both waveforms and compare them.
    for index, waveform_params in enumerate(waveform_table, 1):
        if options.verbose:
            # Floor division keeps the percentage an integer on any Python.
            update_progress(index * 100 // num_waveforms)

        try:
            htilde1 = get_waveform(options.waveform1_approximant,
                                   options.waveform1_phase_order,
                                   options.waveform1_amplitude_order,
                                   options.waveform1_spin_order,
                                   options.waveform1_taper_template,
                                   waveform_params,
                                   options.waveform1_start_frequency,
                                   options.filter_sample_rate,
                                   filter_N)

            htilde2 = get_waveform(options.waveform2_approximant,
                                   options.waveform2_phase_order,
                                   options.waveform2_amplitude_order,
                                   options.waveform2_spin_order,
                                   options.waveform2_taper_template,
                                   waveform_params,
                                   options.waveform2_start_frequency,
                                   options.filter_sample_rate,
                                   filter_N)

            m, i = match(htilde1, htilde2, **filter_kwargs)
            o = overlap(htilde1, htilde2, **filter_kwargs)
            s1 = sigma(htilde1, **filter_kwargs)
            s2 = sigma(htilde2, **filter_kwargs)

            # Indices past filter_n are shifted down by filter_N, giving a
            # negative offset, before converting samples to seconds.
            if i > filter_n:
                i = i - filter_N
            row = (m, o, i * 1. / options.filter_sample_rate, s1, s2)
        except Exception as e:
            # Best effort: a failure for one row is logged and recorded as
            # -1 in every column so the output stays aligned with the input.
            logging.warning("Unable to generate waveforms")
            logging.warning("Error: %s, %s", str(type(e)), str(e))
            row = (-1, -1, -1, -1, -1)

        # Append exactly once per row, outside the try block, so the five
        # result lists can never get out of step with each other.
        for results, value in zip(
                (matches, overlaps, time_offsets, s1s, s2s), row):
            results.append(value)

# Output the results to the match file, one line per waveform:
# match, overlap, time offset, sigma of waveform 1, sigma of waveform 2.
# The with-block closes (and thereby flushes) the file once writing is done.
with fout:
    for m, o, i, s1, s2 in zip(matches, overlaps, time_offsets, s1s, s2s):
        fout.write("%5.5f %5.5f %5.5f %5.5f %5.5f\n" % (m, o, i, s1, s2))
back to top