#! /usr/bin/env python
# Copyright (C) 2012 Alex Nitz
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""Calculate the fitting factors of simulated signals with a template bank."""

from __future__ import print_function
import logging
from tqdm import tqdm
from numpy import complex64, array
from argparse import ArgumentParser
from glue.ligolw import utils as ligolw_utils
from glue.ligolw import table, lsctables
from glue.ligolw.ligolw import LIGOLWContentHandler


class mycontenthandler(LIGOLWContentHandler):
    """Content handler passed to ``lsctables.use_in`` and used to load XML."""
    pass


lsctables.use_in(mycontenthandler)

from pycbc.pnutils import mass1_mass2_to_mchirp_eta
from pycbc.pnutils import mass1_mass2_to_tau0_tau3
from pycbc.pnutils import nearest_larger_binary_number
from pycbc.waveform import get_td_waveform, get_fd_waveform, td_approximants, fd_approximants
from pycbc.waveform.utils import taper_timeseries
from pycbc import DYN_RANGE_FAC
from pycbc.types import FrequencySeries, TimeSeries, zeros, complex_same_precision_as
from pycbc.filter import match, sigmasq
from math import ceil, log
import pycbc.psd, pycbc.scheme, pycbc.fft, pycbc.strain, pycbc.version
from pycbc.detector import overhead_antenna_pattern as generate_fplus_fcross
from pycbc.waveform import TemplateBank


## Remove the need for these functions ########################################

def generate_detector_strain(template_params, h_plus, h_cross):
    """Combine the two polarizations into a single detector strain.

    The ``latitude``, ``longitude`` and ``polarization`` attributes of
    ``template_params`` are used when present; each one defaults to 0 when
    the attribute is missing.

    Parameters
    ----------
    template_params : object
        Row/record describing the waveform.
    h_plus, h_cross
        Plus and cross polarizations of the waveform.

    Returns
    -------
    ``h_plus * f_plus + h_cross * f_cross`` where the antenna factors come
    from ``overhead_antenna_pattern``.
    """
    latitude = getattr(template_params, 'latitude', 0)
    longitude = getattr(template_params, 'longitude', 0)
    polarization = getattr(template_params, 'polarization', 0)

    f_plus, f_cross = generate_fplus_fcross(longitude, latitude, polarization)
    return h_plus * f_plus + h_cross * f_cross


def make_padded_frequency_series(vec, filter_N=None, delta_f=None):
    """Convert vec (TimeSeries or FrequencySeries) to a FrequencySeries.

    If filter_N and/or delta_f are given, the output will take those values.
    If not told otherwise the code will attempt to pad a timeseries first
    such that the waveform will not wraparound. However, if delta_f is
    specified to be shorter than the waveform length then wraparound *will*
    be allowed.

    Parameters
    ----------
    vec : TimeSeries or FrequencySeries
        Input data. A TimeSeries is resized *in place* before being
        transformed.
    filter_N : int, optional
        Number of time-domain samples the output corresponds to. When
        omitted, the power of two following the one that fits ``len(vec)``
        is used.
    delta_f : float, optional
        Frequency spacing of the output. Required when ``vec`` is a
        TimeSeries; for a FrequencySeries it is overwritten with
        ``vec.delta_f``.

    Returns
    -------
    FrequencySeries
        complex64 series scaled by ``DYN_RANGE_FAC``.

    Raises
    ------
    TypeError
        If ``vec`` is neither a TimeSeries nor a FrequencySeries, or if it is
        a TimeSeries and ``delta_f`` was not supplied.
    """
    if filter_N is None:
        # math.ceil returns a float on Python 2; force an integer exponent so
        # that N and n below can be used as array lengths.
        power = int(ceil(log(len(vec), 2))) + 1
        N = 2 ** power
    else:
        N = filter_N
    n = N // 2 + 1

    if isinstance(vec, FrequencySeries):
        vectilde = FrequencySeries(zeros(n, dtype=complex_same_precision_as(vec)),
                                   delta_f=1.0, copy=False)
        # Copy as many samples as fit; the remainder stays zero padded.
        cplen = min(len(vectilde), len(vec))
        vectilde[0:cplen] = vec[0:cplen]
        delta_f = vec.delta_f
    elif isinstance(vec, TimeSeries):
        if delta_f is None:
            raise TypeError("delta_f must be given when vec is a TimeSeries")

        # First determine if the timeseries is too short for the specified df
        # and increase if necessary
        curr_length = len(vec)
        new_length = int(nearest_larger_binary_number(curr_length))
        while new_length * vec.delta_t < 1. / delta_f:
            new_length = new_length * 2
        vec.resize(new_length)

        # Then convert to frequencyseries
        v_tilde = vec.to_frequencyseries()

        # Then convert frequencyseries to required length and spacing by
        # keeping only every nth sample if delta_f needs increasing, and
        # cutting at Nyquist if the max frequency is too high.
        # NOTE: This assumes that the input and output data is using binary
        # lengths.
        i_delta_f = v_tilde.get_delta_f()
        v_tilde = v_tilde.numpy()
        df_ratio = int(delta_f / i_delta_f)
        n_freq_len = int((n - 1) * df_ratio + 1)
        assert n <= len(v_tilde)
        v_tilde = v_tilde[:n_freq_len:df_ratio]
        vectilde = FrequencySeries(v_tilde, delta_f=delta_f, dtype=complex64)
    else:
        raise TypeError("vec must be a TimeSeries or FrequencySeries, got %s"
                        % type(vec).__name__)

    return FrequencySeries(vectilde * DYN_RANGE_FAC, delta_f=delta_f,
                           dtype=complex64)


def get_waveform(approximant, phase_order, amplitude_order, spin_order,
                 wf_params, start_frequency, sample_rate, length,
                 filter_rate):
    """Generate a waveform and return it as a padded frequency series.

    Parameters
    ----------
    approximant : str
        Name of a frequency-domain or time-domain approximant.
    phase_order, amplitude_order, spin_order : int
        Orders forwarded to the waveform generator.
    wf_params : object
        Waveform parameters forwarded to the generator. An optional ``taper``
        attribute is applied to time-domain waveforms.
    start_frequency : float
        Passed to the generator as ``f_lower``.
    sample_rate : float
        Sample rate used for time-domain generation (``delta_t`` is its
        inverse).
    length : int
        Number of time-domain samples of the filter; sets the output length.
    filter_rate : float
        Filter sample rate; ``filter_rate / length`` is the output delta_f.

    Returns
    -------
    FrequencySeries
        Output of ``make_padded_frequency_series`` for the detector strain.

    Raises
    ------
    ValueError
        If ``approximant`` is neither a known FD nor TD approximant.
    """
    # The file has no `division` future import, so force float division.
    delta_f = float(filter_rate) / length

    if approximant in fd_approximants():
        hp, hc = get_fd_waveform(wf_params, approximant=approximant,
                                 phase_order=phase_order, delta_f=delta_f,
                                 spin_order=spin_order,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
    elif approximant in td_approximants():
        hp, hc = get_td_waveform(wf_params, approximant=approximant,
                                 phase_order=phase_order,
                                 spin_order=spin_order,
                                 delta_t=1. / sample_rate,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
        if hasattr(wf_params, 'taper'):
            hp = taper_timeseries(hp, wf_params.taper)
            hc = taper_timeseries(hc, wf_params.taper)
    else:
        raise ValueError("Approximant %s is neither a time-domain nor a "
                         "frequency-domain approximant" % approximant)

    hvec = generate_detector_strain(wf_params, hp, hc)
    # Use the `length` argument rather than reaching for a module global.
    return make_padded_frequency_series(hvec, length, delta_f=delta_f)


aprs = sorted(list(set(td_approximants() + fd_approximants())))

#File output Settings
parser = ArgumentParser(description=__doc__)
parser.add_argument("--match-file", dest="out_file", metavar="FILE",
                    required=True, help="File to output match results")
parser.add_argument("--verbose", action='store_true', default=False,
                    help="Print verbose statements")
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)

#Template Settings
parser.add_argument("--template-file", dest="bank_file", metavar="FILE",
                    required=True,
                    help="SimInspiral or SnglInspiral XML file "
                         "containing the template parameters")
parser.add_argument("--total-mass-divide", type=float,
                    help="Total mass to switch from --template-approximant to "
                         "--highmass-approximant.")
parser.add_argument("--highmass-approximant", choices=aprs,
                    help="Waveform approximant for highmass templates.")
parser.add_argument("--template-approximant", choices=aprs, required=True,
                    help="Waveform approximant for templates")
parser.add_argument("--template-phase-order", default=-1, type=int,
                    help="PN order to use for the template phase")
parser.add_argument("--template-amplitude-order", default=-1, type=int,
                    help="PN order to use for the template amplitude")
parser.add_argument("--template-spin-order", default=-1, type=int,
                    help="PN order to use for the template spin terms")
parser.add_argument("--template-start-frequency", type=float,
                    help="Starting frequency for templates [Hz]")
parser.add_argument("--template-sample-rate", type=float,
                    help="Sample rate for templates [Hz]")

#Signal Settings
parser.add_argument("--signal-file", dest="sim_file", metavar="FILE",
                    required=True,
                    help="SimInspiral or SnglInspiral XML file "
                         "containing the signal parameters")
parser.add_argument("--signal-approximant", choices=aprs, required=True,
                    help="Waveform approximant for signals")
parser.add_argument("--signal-phase-order", default=-1, type=int,
                    help="PN order to use for the signal phase")
parser.add_argument("--signal-spin-order", default=-1, type=int,
                    help="PN order to use for the signal spin terms")
parser.add_argument("--signal-amplitude-order", default=-1, type=int,
                    help="PN order to use for the signal amplitude")
parser.add_argument("--signal-start-frequency", type=float,
                    help="Starting frequency for signals [Hz]")
parser.add_argument("--signal-sample-rate", type=float,
                    help="Sample rate for signals [Hz]")
parser.add_argument("--use-sky-location", action='store_true',
                    help="Inject into a theoretical detector at the celestial "
                         "North pole of a non-rotating Earth rather than overhead")

#Filtering Settings
parser.add_argument('--filter-low-frequency-cutoff', metavar='FREQ',
                    type=float, required=True,
                    help='low frequency cutoff of matched filter')
parser.add_argument("--filter-sample-rate", type=float, required=True,
                    help="Filter sample rate [Hz]")
parser.add_argument("--filter-signal-length", type=int, required=True,
                    help="Length of signal for filtering, shoud be longer "
                         "than all waveforms and include some padding")

# add PSD options
pycbc.psd.insert_psd_option_group(parser, output=False)

# Insert the data reading options
pycbc.strain.insert_strain_option_group(parser)

#hardware support
pycbc.scheme.insert_processing_option_group(parser)
pycbc.fft.insert_fft_option_group(parser)

#Restricted maximization
parser.add_argument("--mchirp-window", type=str, metavar="FRACTION",
                    help="Ignore templates whose chirp mass deviates from "
                         "signal's one more than given fraction. Provide two "
                         "comma separated numbers to have different bounds "
                         "above and below the signal's, with below bound "
                         "listed first.")
parser.add_argument("--tau0-window", type=float, metavar="TIME", default=None,
                    help="Ignore templates whose Newtonian order chirp time "
                         "(tau0) varies from the signals by more than the "
                         "supplied amount. If option is not provided no "
                         "window on tau0 is used. The "
                         "filter-low-frequency-cutoff is used to calculate "
                         "the value of tau0 for all cases. Provided in units "
                         "of seconds.")

options = parser.parse_args()

pycbc.init_logging(options.verbose)

pycbc.psd.verify_psd_options(options, parser)

if options.psd_estimation:
    pycbc.strain.verify_strain_options(options, parser)

# Test against None (not truthiness) so that a divide of 0.0 is validated the
# same way it is later applied when choosing the template approximant.
if options.total_mass_divide is not None and options.highmass_approximant is None:
    parser.error("You must provide a highmass-approximant if you want total-mass-divide.")

if options.mchirp_window is None:
    def outside_mchirp_window(template_mchirp, signal_mchirp):
        """No chirp mass window requested: never reject a template."""
        return False
elif ',' in options.mchirp_window:
    # asymmetric chirp mass window; lower bound is listed first
    _mchirp_bounds = options.mchirp_window.split(",")
    mchirp_window_lower = float(_mchirp_bounds[0])
    mchirp_window_upper = float(_mchirp_bounds[1])

    def outside_mchirp_window(template_mchirp, signal_mchirp):
        """Return True if the fractional chirp mass offset exceeds a bound."""
        delta = (template_mchirp - signal_mchirp) / signal_mchirp
        return delta > mchirp_window_upper or -delta > mchirp_window_lower
else:
    # symmetric chirp mass window
    mchirp_window = float(options.mchirp_window)

    def outside_mchirp_window(template_mchirp, signal_mchirp):
        """Return True if the chirp masses differ by more than the window."""
        return abs(signal_mchirp - template_mchirp) > \
            (mchirp_window * signal_mchirp)

if options.tau0_window is None:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        """No tau0 window requested: never reject a template."""
        return False
else:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        """Return True if the tau0 values differ by more than ``window``."""
        return abs(signal_tau0 - template_tau0) > window

# If we are going to use h(t) to estimate a PSD we need h(t)
if options.psd_estimation:
    logging.info("Obtaining h(t) for PSD generation")
    strain = pycbc.strain.from_cli(options, pycbc.DYN_RANGE_FAC)
else:
    strain = None

# Sample rates fall back on the filter sample rate when not given explicitly
if options.template_sample_rate is not None:
    template_sample_rate = options.template_sample_rate
else:
    template_sample_rate = options.filter_sample_rate

if options.signal_sample_rate is not None:
    signal_sample_rate = options.signal_sample_rate
else:
    signal_sample_rate = options.filter_sample_rate

ctx = pycbc.scheme.from_cli(options)

logging.info('Reading template bank')
temp_bank = TemplateBank(options.bank_file)
template_table = temp_bank.table
logging.info(" %d templates", len(template_table))

logging.info('Reading simulation list')
indoc = ligolw_utils.load_filename(options.sim_file, False,
                                   contenthandler=mycontenthandler)
try:
    signal_table = table.get_table(indoc, lsctables.SimInspiralTable.tableName)
except ValueError:
    # Fall back on a SnglInspiral table if no SimInspiral table was found
    signal_table = table.get_table(indoc, lsctables.SnglInspiralTable.tableName)
logging.info(" %d signal waveforms", len(signal_table))

logging.info("Matches will be written to %s", options.out_file)

filter_N = int(options.filter_signal_length * options.filter_sample_rate)
filter_n = filter_N // 2 + 1
filter_delta_f = 1.0 / float(options.filter_signal_length)

logging.info("Reading and Interpolating PSD")
psd = pycbc.psd.from_cli(options, filter_n, filter_delta_f,
                         options.filter_low_frequency_cutoff, strain=strain,
                         dyn_range_factor=pycbc.DYN_RANGE_FAC,
                         precision='single')

with ctx:
    pycbc.fft.from_cli(options)

    logging.info("Pregenerating Signals")
    # Each entry is (stilde / psd, sigmasq of stilde, list of matches, params)
    signals = []
    # Used for getting mchirp/tau0 later
    sig_m1 = []
    sig_m2 = []
    prog = tqdm(total=len(signal_table), disable=(not options.verbose))
    for signal_params in signal_table:
        prog.update(1)
        if not options.use_sky_location and hasattr(signal_params, 'latitude'):
            signal_params.latitude = 0.
            signal_params.longitude = 0.

        stilde = get_waveform(options.signal_approximant,
                              options.signal_phase_order,
                              options.signal_amplitude_order,
                              options.signal_spin_order,
                              signal_params,
                              options.signal_start_frequency,
                              signal_sample_rate,
                              filter_N, options.filter_sample_rate)
        s_norm = sigmasq(stilde, psd=psd,
                         low_frequency_cutoff=options.filter_low_frequency_cutoff)
        # The norm is taken before this division; the whitened signal is what
        # gets stored and later handed to match() without a psd argument.
        stilde /= psd
        signals.append((stilde, s_norm, [], signal_params))
        sig_m1.append(signal_params.mass1)
        sig_m2.append(signal_params.mass2)
    prog.close()

    sig_m1 = array(sig_m1)
    sig_m2 = array(sig_m2)
    sig_tau0, _ = mass1_mass2_to_tau0_tau3(sig_m1, sig_m2,
                                           options.filter_low_frequency_cutoff)
    sig_mchirp, _ = mass1_mass2_to_mchirp_eta(sig_m1, sig_m2)

    logging.info("Calculating Mchirp and Tau0")
    template_m1 = array([tp.mass1 for tp in template_table])
    template_m2 = array([tp.mass2 for tp in template_table])
    template_tau0, _ = mass1_mass2_to_tau0_tau3(template_m1, template_m2,
                                                options.filter_low_frequency_cutoff)
    template_mchirp, _ = mass1_mass2_to_mchirp_eta(template_m1, template_m2)

    logging.info("Calculating Overlaps")
    flow_warned = False
    prog = tqdm(total=len(template_table), disable=(not options.verbose))
    for index, template_params in enumerate(template_table):
        prog.update(1)
        f_lower = template_params.f_lower
        # If not set fall back on filter low-freq cutoff
        if f_lower < 0.000001:
            f_lower = options.filter_low_frequency_cutoff
        if f_lower < options.filter_low_frequency_cutoff:
            # Not entirely clear what to do here?
            if not flow_warned:
                logging.warning("Template's flower is smaller than "
                                "--filter-low-frequency-cutoff. Raising flower "
                                "of template to match.")
                flow_warned = True
            f_lower = options.filter_low_frequency_cutoff

        # The template waveform is generated lazily, only once a signal
        # actually falls inside the tau0 / chirp mass windows.
        h_norm = htilde = None
        for sidx, (stilde, s_norm, matches, signal_params) in enumerate(signals):
            # Check if we need to look at this
            skip = (stilde is None
                    or outside_tau0_window(template_tau0[index],
                                           sig_tau0[sidx],
                                           options.tau0_window)
                    or outside_mchirp_window(template_mchirp[index],
                                             sig_mchirp[sidx]))
            if skip:
                matches.append(0)
                continue

            # Generate htilde if we haven't already done so
            if htilde is None:
                # FIXME: I would like to remove the approximant options and
                # have this entirely controlled by the template bank.
                # However, while we are still using the high-mass divide
                # in XML banks, this must be retained.
                try:
                    this_approximant = template_params['approximant']
                except (KeyError, IndexError, ValueError, TypeError,
                        AttributeError):
                    # No per-template approximant available; use the option.
                    this_approximant = options.template_approximant
                # NOTE(review): the nesting of this override relative to the
                # try/except was reconstructed from whitespace-collapsed
                # source -- confirm against the canonical file.
                if options.total_mass_divide is not None and \
                        (template_params.mass1 + template_params.mass2) >= options.total_mass_divide:
                    this_approximant = options.highmass_approximant

                htilde = get_waveform(this_approximant,
                                      options.template_phase_order,
                                      options.template_amplitude_order,
                                      options.template_spin_order,
                                      template_params,
                                      options.template_start_frequency,
                                      template_sample_rate,
                                      filter_N, options.filter_sample_rate)
                h_norm = sigmasq(htilde, psd=psd,
                                 low_frequency_cutoff=f_lower)

            o, i = match(htilde, stilde, v1_norm=h_norm, v2_norm=s_norm,
                         low_frequency_cutoff=f_lower)
            matches.append(o)
    prog.close()

logging.info("Determining maximum overlaps and outputting results")

# Find the maximum overlap in the bank and output to a file
with open(options.out_file, "w") as fout:
    for i, (stilde, s_norm, matches, sim_template) in enumerate(signals):
        best_match = max(matches)
        match_str = "%5.5f " % best_match
        match_str += " " + options.bank_file
        match_str += " " + str(matches.index(best_match))
        match_str += " " + options.sim_file
        match_str += " %d" % i
        match_str += " %5.5f\n" % s_norm
        fout.write(match_str)