https://github.com/gwastro/pycbc
Raw File
Tip revision: b574af5aeb8e5b22c040f0e6cecd6796e6a5cb69 authored by Duncan Brown on 27 November 2018, 23:45:28 UTC
Set for 1.13.2 release
Tip revision: b574af5
pycbc_banksim
#! /usr/bin/env python
# Copyright (C) 2012  Alex Nitz
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Calculate the fitting factors of simulated signals with a template bank."""

import sys
import logging
from numpy import complex64, float32, array
from argparse import ArgumentParser
from glue.ligolw import utils as ligolw_utils
from glue.ligolw import table, lsctables
from glue.ligolw.ligolw import LIGOLWContentHandler
class mycontenthandler(LIGOLWContentHandler):
    """LIGO_LW content handler passed to ``load_filename`` for the signal file.

    It adds nothing to ``LIGOLWContentHandler``; the ``lsctables.use_in`` call
    that follows configures it with the lsctables table definitions
    (presumably so tables are parsed into lsctables' row classes).
    """
lsctables.use_in(mycontenthandler)

from pycbc.pnutils import mass1_mass2_to_mchirp_eta, f_SchwarzISCO
from pycbc.pnutils import mass1_mass2_to_tau0_tau3
from pycbc.pnutils import nearest_larger_binary_number
from pycbc.waveform import get_td_waveform, get_fd_waveform, td_approximants, fd_approximants
from pycbc.waveform.utils import taper_timeseries
from pycbc import DYN_RANGE_FAC
from pycbc.types import FrequencySeries, TimeSeries, zeros, real_same_precision_as, complex_same_precision_as
from pycbc.filter import match, sigmasq, resample_to_delta_t
from math import ceil, log
import pycbc.psd, pycbc.scheme, pycbc.fft, pycbc.strain, pycbc.version
from pycbc.detector import overhead_antenna_pattern as generate_fplus_fcross
from pycbc.waveform import TemplateBank

def update_progress(progress):
    """Draw a 50-character text progress bar on stdout.

    The bar is redrawn in place: it ends with carriage returns rather than a
    newline. When ``progress`` equals 1, "Done" is written over the bar line
    and the line is terminated.

    Parameters
    ----------
    progress : float
        Fraction of the work completed, expected to lie in [0, 1].
    """
    # Each '#' represents 2%. Floor division keeps the repeat counts integral
    # under both Python 2 and Python 3.
    filled = int(progress * 100) // 2
    bar = '#' * filled + ' ' * (50 - filled)
    sys.stdout.write('[{0}] {1:.2%}\r\r'.format(bar, progress))
    if progress == 1:
        # Trailing spaces presumably blank out what is left of the bar line.
        sys.stdout.write('Done' + ' ' * 70 + '\n')
    sys.stdout.flush()

## TODO: remove the need for the helper functions below ######################
    
def generate_detector_strain(template_params, h_plus, h_cross):
    """Combine the two polarizations into a single detector strain.

    The optional ``latitude``, ``longitude`` and ``polarization`` attributes
    of ``template_params`` are passed (as right ascension / declination /
    polarization, in that call order: longitude, latitude, polarization) to
    ``pycbc.detector.overhead_antenna_pattern``; any attribute that is
    missing is taken as 0.

    Returns ``h_plus * F+ + h_cross * Fx``.
    """
    angles = {}
    for name in ('latitude', 'longitude', 'polarization'):
        angles[name] = (getattr(template_params, name)
                        if hasattr(template_params, name) else 0)

    f_plus, f_cross = generate_fplus_fcross(angles['longitude'],
                                            angles['latitude'],
                                            angles['polarization'])
    return h_plus * f_plus + h_cross * f_cross

def make_padded_frequency_series(vec, filter_N=None, delta_f=None):
    """Convert vec (TimeSeries or FrequencySeries) to a FrequencySeries.

    If filter_N and/or delta_f are given, the output will take those values.
    If not told otherwise, the code pads a time series first so that the
    waveform will not wrap around. However, if 1/delta_f is shorter than the
    waveform duration, wraparound *will* be allowed.

    Parameters
    ----------
    vec : TimeSeries or FrequencySeries
        Waveform to convert. A TimeSeries is resized (padded) in place before
        being Fourier transformed.
    filter_N : int, optional
        Number of time samples of the filter. The output has
        ``filter_N // 2 + 1`` frequency samples. Defaults to twice the next
        power of two at or above ``len(vec)``.
    delta_f : float, optional
        Frequency spacing for a TimeSeries input. Defaults to
        ``1 / (N * vec.delta_t)``. Ignored for a FrequencySeries input, whose
        own ``delta_f`` is kept.

    Returns
    -------
    FrequencySeries
        complex64 series of ``N // 2 + 1`` samples, multiplied by
        ``DYN_RANGE_FAC``.

    Raises
    ------
    TypeError
        If vec is neither a TimeSeries nor a FrequencySeries.
    """
    if filter_N is None:
        # int() so that N and n are integers rather than floats.
        power = int(ceil(log(len(vec), 2))) + 1
        N = 2 ** power
    else:
        N = int(filter_N)
    n = N // 2 + 1

    if isinstance(vec, FrequencySeries):
        # Copy into an n-sample zero array, truncating if vec is longer.
        vectilde = FrequencySeries(zeros(n, dtype=complex_same_precision_as(vec)),
                                   delta_f=1.0, copy=False)
        cplen = min(len(vectilde), len(vec))
        vectilde[0:cplen] = vec[0:cplen]
        delta_f = vec.delta_f

    elif isinstance(vec, TimeSeries):
        if delta_f is None:
            # Spacing of an N-sample transform at the input's sample rate.
            delta_f = 1.0 / (N * vec.delta_t)
        # First determine if the timeseries is too short for the specified df
        # and increase if necessary
        new_length = int(nearest_larger_binary_number(len(vec)))
        while new_length * vec.delta_t < 1. / delta_f:
            new_length *= 2
        vec.resize(new_length)
        # Then convert to frequencyseries
        v_tilde = vec.to_frequencyseries()
        # Then convert the frequencyseries to the required length and spacing
        # by keeping only every df_ratio-th sample if delta_f needs
        # increasing, and cutting at the filter's Nyquist frequency.
        # NOTE: This assumes that the input and output data use binary
        #       lengths, so that delta_f / i_delta_f is an integer.
        i_delta_f = v_tilde.get_delta_f()
        v_tilde = v_tilde.numpy()
        df_ratio = int(delta_f / i_delta_f)
        n_freq_len = int((n - 1) * df_ratio + 1)
        v_tilde = v_tilde[:n_freq_len:df_ratio]
        if len(v_tilde) < n:
            # The waveform was sampled below the filter rate, so its spectrum
            # stops short of the filter's Nyquist frequency. Zero-fill up to
            # n samples instead of returning a too-short series.
            import numpy
            padded = numpy.zeros(n, dtype=v_tilde.dtype)
            padded[:len(v_tilde)] = v_tilde
            v_tilde = padded
        vectilde = FrequencySeries(v_tilde, delta_f=delta_f, dtype=complex64)

    else:
        raise TypeError("vec must be a TimeSeries or FrequencySeries, got %s"
                        % type(vec).__name__)

    return FrequencySeries(vectilde * DYN_RANGE_FAC, delta_f=delta_f,
                           dtype=complex64)

def get_waveform(approximant, phase_order, amplitude_order, spin_order,
                 wf_params, start_frequency, sample_rate, length,
                 filter_rate):
    """Generate a waveform's detector strain as a padded FrequencySeries.

    Parameters
    ----------
    approximant : str
        Waveform approximant. It must appear in ``fd_approximants()`` or
        ``td_approximants()``.
    phase_order, amplitude_order, spin_order : int
        PN orders forwarded to the waveform generator.
    wf_params : object
        Waveform parameter row. The optional ``latitude``, ``longitude`` and
        ``polarization`` attributes feed the antenna pattern. An optional
        ``taper`` attribute tapers time-domain waveforms.
    start_frequency : float
        Frequency at which waveform generation starts [Hz].
    sample_rate : float
        Sample rate used for time-domain generation [Hz].
    length : int
        Number of time samples of the filter. The output has
        ``length // 2 + 1`` frequency samples.
    filter_rate : float
        Filter sample rate [Hz]. The frequency spacing is
        ``filter_rate / length``.

    Returns
    -------
    FrequencySeries
        Result of ``make_padded_frequency_series`` for the detector strain.

    Raises
    ------
    ValueError
        If approximant is neither a frequency- nor a time-domain approximant.
    """
    delta_f = float(filter_rate) / length
    if approximant in fd_approximants():
        hp, hc = get_fd_waveform(wf_params, approximant=approximant,
                                 phase_order=phase_order, delta_f=delta_f,
                                 spin_order=spin_order,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)

    elif approximant in td_approximants():
        hp, hc = get_td_waveform(wf_params,
                                 approximant=approximant,
                                 phase_order=phase_order,
                                 spin_order=spin_order,
                                 delta_t=1./sample_rate,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
        if hasattr(wf_params, 'taper'):
            hp = taper_timeseries(hp, wf_params.taper)
            hc = taper_timeseries(hc, wf_params.taper)

    else:
        raise ValueError("Unknown waveform approximant: %r" % (approximant,))

    hvec = generate_detector_strain(wf_params, hp, hc)
    # Pad to the caller-supplied filter length rather than relying on a
    # module-level global.
    return make_padded_frequency_series(hvec, length, delta_f=delta_f)

aprs = sorted(set(td_approximants() + fd_approximants()))

# File output settings
parser = ArgumentParser(description=__doc__)
parser.add_argument("--match-file", dest="out_file", metavar="FILE",
                    required=True, help="File to output match results")
parser.add_argument("--verbose", action='store_true', default=False,
                    help="Print verbose statements")
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)

# Template settings
parser.add_argument("--template-file", dest="bank_file", metavar="FILE",
                    required=True, help="SimInspiral or SnglInspiral XML file "
                                        "containing the template parameters")
parser.add_argument("--total-mass-divide", type=float,
                    help="Total mass at or above which templates use "
                         "--highmass-approximant instead of "
                         "--template-approximant.")
parser.add_argument("--highmass-approximant", choices=aprs,
                    help="Waveform approximant for highmass templates.")
parser.add_argument("--template-approximant", choices=aprs, required=True,
                    help="Waveform approximant for templates")
parser.add_argument("--template-phase-order", default=-1, type=int,
                    help="PN order to use for the template phase")
parser.add_argument("--template-amplitude-order", default=-1, type=int,
                    help="PN order to use for the template amplitude")
parser.add_argument("--template-spin-order", default=-1, type=int,
                    help="PN order to use for the template spin terms")
parser.add_argument("--template-start-frequency", type=float,
                    help="Starting frequency for templates [Hz]")
parser.add_argument("--template-sample-rate", type=float,
                    help="Sample rate for templates [Hz] "
                         "(default: --filter-sample-rate)")

# Signal settings
parser.add_argument("--signal-file", dest="sim_file", metavar="FILE",
                    required=True, help="SimInspiral or SnglInspiral XML file "
                                        "containing the signal parameters")
parser.add_argument("--signal-approximant", choices=aprs, required=True,
                    help="Waveform approximant for signals")
parser.add_argument("--signal-phase-order", default=-1, type=int,
                    help="PN order to use for the signal phase")
parser.add_argument("--signal-spin-order", default=-1, type=int,
                    help="PN order to use for the signal spin terms")
parser.add_argument("--signal-amplitude-order", default=-1, type=int,
                    help="PN order to use for the signal amplitude")
parser.add_argument("--signal-start-frequency", type=float,
                    help="Starting frequency for signals [Hz]")
parser.add_argument("--signal-sample-rate", type=float,
                    help="Sample rate for signals [Hz] "
                         "(default: --filter-sample-rate)")
parser.add_argument("--use-sky-location", action='store_true',
                    help="Inject into a theoretical detector at the celestial "
                         "North pole of a non-rotating Earth rather than overhead")

# Filtering settings
parser.add_argument('--filter-low-frequency-cutoff', metavar='FREQ', type=float,
                    required=True, help='low frequency cutoff of matched filter')
parser.add_argument("--filter-sample-rate", type=float, required=True,
                    help="Filter sample rate [Hz]")
parser.add_argument("--filter-signal-length", type=int, required=True,
                    help="Length of signal for filtering [s]; should be longer "
                         "than all waveforms and include some padding")

# add PSD options
pycbc.psd.insert_psd_option_group(parser, output=False)

# Insert the data reading options
pycbc.strain.insert_strain_option_group(parser)

# hardware support
pycbc.scheme.insert_processing_option_group(parser)
pycbc.fft.insert_fft_option_group(parser)

# Restricted maximization
parser.add_argument("--mchirp-window", type=str, metavar="FRACTION",
                    help="Ignore templates whose chirp mass deviates from the "
                         "signal's by more than the given fraction. Provide "
                         "two comma-separated numbers to have different "
                         "bounds below and above the signal's chirp mass, "
                         "with the lower bound listed first.")
parser.add_argument("--tau0-window", type=float, metavar="TIME", default=None,
                    help="Ignore templates whose Newtonian order chirp time "
                         "(tau0) differs from the signal's by more than the "
                         "supplied amount. If option is not provided no "
                         "window on tau0 is used. The "
                         "filter-low-frequency-cutoff is used to calculate "
                         "the value of tau0 for all cases. Provided in units "
                         "of seconds.")

options = parser.parse_args()

pycbc.init_logging(options.verbose)

pycbc.psd.verify_psd_options(options, parser)

# Strain options only need checking when h(t) is read to estimate the PSD.
if options.psd_estimation:
    pycbc.strain.verify_strain_options(options, parser)

# Compare against None (as the template loop does when choosing the
# approximant) so that a divide of 0 also requires a high-mass approximant.
if (options.total_mass_divide is not None
        and options.highmass_approximant is None):
    parser.error("You must provide a highmass-approximant if you want total-mass-divide.")

if options.mchirp_window is None:
    def outside_mchirp_window(template_mchirp, signal_mchirp):
        """Never exclude a template: no --mchirp-window was given."""
        return False
else:
    # Accept one fraction (symmetric window) or "lower,upper" fractions.
    # Anything else is reported as a usage error instead of a raw traceback,
    # and extra values are not silently ignored.
    try:
        mchirp_window_bounds = [float(bound)
                                for bound in options.mchirp_window.split(",")]
    except ValueError:
        # parser.error() exits the program.
        parser.error("--mchirp-window must be one number or two "
                     "comma-separated numbers, got %r" % options.mchirp_window)

    if len(mchirp_window_bounds) == 1:
        # symmetric chirp mass window
        mchirp_window = mchirp_window_bounds[0]
        def outside_mchirp_window(template_mchirp, signal_mchirp):
            """True if |signal - template| exceeds the window times signal mchirp."""
            return abs(signal_mchirp - template_mchirp) > \
                    (mchirp_window * signal_mchirp)
    elif len(mchirp_window_bounds) == 2:
        # asymmetric chirp mass window; the lower bound is listed first
        mchirp_window_lower, mchirp_window_upper = mchirp_window_bounds
        def outside_mchirp_window(template_mchirp, signal_mchirp):
            """True if the fractional offset exceeds the bound on its side."""
            delta = (template_mchirp - signal_mchirp) / signal_mchirp
            return delta > mchirp_window_upper or -delta > mchirp_window_lower
    else:
        parser.error("--mchirp-window takes at most two comma-separated "
                     "numbers, got %r" % options.mchirp_window)

if options.tau0_window is not None:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        """True when the two tau0 values differ by more than ``window``."""
        return abs(signal_tau0 - template_tau0) > window
else:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        """No --tau0-window given: never exclude a template."""
        return False

# h(t) is only read when it is needed to estimate the PSD.
strain = None
if options.psd_estimation:
    logging.info("Obtaining h(t) for PSD generation")
    strain = pycbc.strain.from_cli(options, pycbc.DYN_RANGE_FAC)

# Template and signal sample rates fall back to the filter sample rate.
template_sample_rate = (options.template_sample_rate
                        if options.template_sample_rate is not None
                        else options.filter_sample_rate)
signal_sample_rate = (options.signal_sample_rate
                      if options.signal_sample_rate is not None
                      else options.filter_sample_rate)

ctx = pycbc.scheme.from_cli(options)

logging.info('Reading template bank')
temp_bank = TemplateBank(options.bank_file)
template_table = temp_bank.table
logging.info("  %d templates", len(template_table))

logging.info('Reading simulation list')
indoc = ligolw_utils.load_filename(options.sim_file, False,
                                   contenthandler=mycontenthandler)
# Use the sim_inspiral table; if looking it up raises ValueError (presumably
# because the file has none), fall back to the sngl_inspiral table.
try:
    signal_table = table.get_table(indoc,
                                   lsctables.SimInspiralTable.tableName)
except ValueError:
    signal_table = table.get_table(indoc,
                                   lsctables.SnglInspiralTable.tableName)
logging.info("  %d signal waveforms", len(signal_table))

logging.info("Matches will be written to %s", options.out_file)

# Filter length in samples, number of one-sided frequency bins, bin spacing.
filter_N = int(options.filter_signal_length * options.filter_sample_rate)
filter_n = filter_N // 2 + 1
filter_delta_f = 1.0 / options.filter_signal_length

logging.info("Reading and Interpolating PSD")
psd = pycbc.psd.from_cli(options, filter_n, filter_delta_f,
                         options.filter_low_frequency_cutoff,
                         strain=strain,
                         dyn_range_factor=pycbc.DYN_RANGE_FAC,
                         precision='single')
  
with ctx:
    pycbc.fft.from_cli(options)

    logging.info("Pregenerating Signals")

    # One entry per signal: (stilde divided by the PSD, sigmasq of stilde,
    # list of matches against each template (filled in below), signal row).
    signals = []
    # Used for getting mchirp/tau0 later
    sig_m1 = []
    sig_m2 = []
    for index, signal_params in enumerate(signal_table):
        if options.verbose:
            update_progress(float(index+1)/len(signal_table))
        if not options.use_sky_location:
            signal_params.latitude = 0.
            signal_params.longitude = 0.
        stilde = get_waveform(options.signal_approximant,
                              options.signal_phase_order,
                              options.signal_amplitude_order,
                              options.signal_spin_order,
                              signal_params,
                              options.signal_start_frequency,
                              signal_sample_rate,
                              filter_N, options.filter_sample_rate)
        s_norm = sigmasq(stilde, psd=psd,
                         low_frequency_cutoff=options.filter_low_frequency_cutoff)
        # match() below is called without a psd argument. Dividing here, after
        # s_norm is computed, presumably supplies the PSD weighting once per
        # signal instead of once per template.
        stilde /= psd
        signals.append((stilde, s_norm, [], signal_params))
        sig_m1.append(signal_params.mass1)
        sig_m2.append(signal_params.mass2)
    sig_m1 = array(sig_m1)
    sig_m2 = array(sig_m2)
    sig_tau0, _ = mass1_mass2_to_tau0_tau3(sig_m1, sig_m2,
                                           options.filter_low_frequency_cutoff)
    sig_mchirp, _ = mass1_mass2_to_mchirp_eta(sig_m1, sig_m2)

    logging.info("Calculating Mchirp and Tau0")
    template_m1 = array([tp.mass1 for tp in template_table])
    template_m2 = array([tp.mass2 for tp in template_table])
    template_tau0, _ = mass1_mass2_to_tau0_tau3(template_m1, template_m2,
                                                options.filter_low_frequency_cutoff)
    template_mchirp, _ = mass1_mass2_to_mchirp_eta(template_m1, template_m2)

    logging.info("Calculating Overlaps")

    flow_warned = False
    for index, template_params in enumerate(template_table):
        if options.verbose:
            update_progress(float(index+1)/len(template_table))

        f_lower = template_params.f_lower
        # If not set fall back on filter low-freq cutoff
        if f_lower < 0.000001:
            f_lower = options.filter_low_frequency_cutoff
        if f_lower < options.filter_low_frequency_cutoff:
            # Not entirely clear what to do here?
            if not flow_warned:
                logging.warning("Template's flower is smaller than "
                                "--filter-low-frequency-cutoff. Raising flower "
                                "of template to match.")
                flow_warned = True
            f_lower = options.filter_low_frequency_cutoff

        # The template is generated lazily, i.e. only once some signal passes
        # the windows, and is then reused for the remaining signals.
        h_norm = htilde = None
        for sidx, (stilde, s_norm, matches, signal_params) in enumerate(signals):
            # Pairs outside the mchirp/tau0 windows are recorded as match 0.
            check_logic = stilde is None
            check_logic |= outside_tau0_window(template_tau0[index],
                                               sig_tau0[sidx],
                                               options.tau0_window)
            check_logic |= outside_mchirp_window(template_mchirp[index],
                                                 sig_mchirp[sidx])
            if check_logic:
                matches.append(0)
                continue

            # Generate htilde if we haven't already done so
            if htilde is None:
                # FIXME: I would like to remove the approximant options and
                #        have this entirely controlled by the template bank.
                #        However, while we are still using the high-mass divide
                #        in XML banks, this must be retained.
                try:
                    this_approximant = template_params['approximant']
                except (KeyError, ValueError, IndexError, TypeError,
                        AttributeError):
                    # The bank row has no (indexable) approximant field: use
                    # the command-line approximant, switching to the high-mass
                    # one at or above --total-mass-divide. Unlike the previous
                    # bare except, this does not swallow KeyboardInterrupt or
                    # unrelated errors.
                    this_approximant = options.template_approximant
                    if (options.total_mass_divide is not None and
                            (template_params.mass1 + template_params.mass2)
                            >= options.total_mass_divide):
                        this_approximant = options.highmass_approximant

                htilde = get_waveform(this_approximant,
                                      options.template_phase_order,
                                      options.template_amplitude_order,
                                      options.template_spin_order,
                                      template_params,
                                      options.template_start_frequency,
                                      template_sample_rate,
                                      filter_N, options.filter_sample_rate)

                h_norm = sigmasq(htilde, psd=psd, low_frequency_cutoff=f_lower)

            overlap, _ = match(htilde, stilde, v1_norm=h_norm, v2_norm=s_norm,
                               low_frequency_cutoff=f_lower)
            matches.append(overlap)

logging.info("Determining maximum overlaps and outputting results")

# Write one line per signal: best match, bank file, index of the best
# template, signal file, signal index, signal sigmasq.
with open(options.out_file, "w") as fout:
    for sim_index, signal_entry in enumerate(signals):
        s_norm, matches = signal_entry[1], signal_entry[2]
        # First index holding the maximum (same choice as list.index(max)).
        best_index = max(range(len(matches)), key=matches.__getitem__)
        fout.write("%5.5f  %s %d %s %d %5.5f\n"
                   % (matches[best_index], options.bank_file, best_index,
                      options.sim_file, sim_index, s_norm))
back to top