#! /usr/bin/env python

# Copyright (C) 2012-17 Alex Nitz, Ian Harry
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.

"""Calculate the fitting factors of simulated signals with a template bank."""

import sys
import logging
from numpy import complex64, float32, sqrt, argmax, real, array
from argparse import ArgumentParser
from glue.ligolw import utils as ligolw_utils
from glue.ligolw import table, lsctables
from glue.ligolw.ligolw import LIGOLWContentHandler


class mycontenthandler(LIGOLWContentHandler):
    """Content handler registered with lsctables for reading the XML files."""
    pass


lsctables.use_in(mycontenthandler)

from pycbc.pnutils import mass1_mass2_to_mchirp_eta, f_SchwarzISCO
from pycbc.pnutils import nearest_larger_binary_number
from pycbc.pnutils import mass1_mass2_to_tau0_tau3
from pycbc.waveform import get_td_waveform, get_fd_waveform, td_approximants, fd_approximants
from pycbc.waveform.utils import taper_timeseries
from pycbc import DYN_RANGE_FAC
from pycbc.types import FrequencySeries, TimeSeries, zeros, real_same_precision_as, complex_same_precision_as
from pycbc.filter import match, sigmasq, resample_to_delta_t
from pycbc.filter import overlap_cplx, matched_filter
from pycbc.filter import compute_max_snr_over_sky_loc_stat
from pycbc.filter import compute_max_snr_over_sky_loc_stat_no_phase
from math import ceil, log
import pycbc.psd
import pycbc.scheme
import pycbc.fft
import pycbc.strain
import pycbc.version
from pycbc.detector import overhead_antenna_pattern as generate_fplus_fcross
from pycbc.waveform import TemplateBank


def update_progress(progress):
    """Draw a 50-character text progress bar on stdout.

    Parameters
    ----------
    progress : float
        Completed fraction; the bar is full and 'Done' is printed when this
        equals 1.
    """
    # Floor division keeps the repeat counts integral on Python 2 and 3.
    filled = int(progress * 100) // 2
    bar = '#' * filled + ' ' * (50 - filled)
    # Written directly (no trailing newline) so the carriage returns let the
    # next update overwrite this one; valid on both Python 2 and Python 3.
    sys.stdout.write('[{0}] {1:.2%}\r\r'.format(bar, progress))
    if progress == 1:
        print('Done' + ' ' * 70)
    sys.stdout.flush()


## Remove the need for these functions ########################################

def generate_detector_strain(template_params, h_plus, h_cross):
    """Combine the two polarizations into a single detector response.

    The `latitude`, `longitude` and `polarization` attributes of
    `template_params` are used when present; each defaults to 0 otherwise.

    Returns
    -------
    h_plus * f_plus + h_cross * f_cross, with the antenna factors taken from
    pycbc.detector.overhead_antenna_pattern.
    """
    latitude = getattr(template_params, 'latitude', 0)
    longitude = getattr(template_params, 'longitude', 0)
    polarization = getattr(template_params, 'polarization', 0)

    f_plus, f_cross = generate_fplus_fcross(longitude, latitude, polarization)
    return h_plus * f_plus + h_cross * f_cross


def make_padded_frequency_series(vec, filter_N=None, delta_f=None):
    """Convert vec (TimeSeries or FrequencySeries) to a FrequencySeries.

    If filter_N and/or delta_f are given, the output will take those values.
    If not told otherwise the code will attempt to pad a timeseries first such
    that the waveform will not wraparound. However, if delta_f is specified to
    be shorter than the waveform length then wraparound *will* be allowed.

    Parameters
    ----------
    vec : TimeSeries or FrequencySeries
        Input data. A TimeSeries is resized in place before transforming.
    filter_N : int, optional
        Time-domain length the output corresponds to; the output holds
        filter_N // 2 + 1 samples. Defaults to twice the power of two at or
        above len(vec).
    delta_f : float, optional
        Output frequency spacing. Required for TimeSeries input; ignored for
        FrequencySeries input, where vec.delta_f is used instead.

    Returns
    -------
    FrequencySeries
        complex64 series scaled by DYN_RANGE_FAC.

    Raises
    ------
    ValueError
        If vec is a TimeSeries and delta_f is not given.
    TypeError
        If vec is neither a TimeSeries nor a FrequencySeries.
    """
    if filter_N is None:
        # math.ceil returns a float on Python 2; keep N an exact integer.
        N = 2 ** (int(ceil(log(len(vec), 2))) + 1)
    else:
        N = filter_N
    # Floor division: the sample count must be an int for zeros() and slicing.
    n = int(N) // 2 + 1

    if isinstance(vec, FrequencySeries):
        # delta_f=1.0 is only a placeholder; the real spacing is applied when
        # the result is re-wrapped at the end of the function.
        vectilde = FrequencySeries(zeros(n, dtype=complex_same_precision_as(vec)),
                                   delta_f=1.0, copy=False)
        cplen = min(len(vectilde), len(vec))
        vectilde[0:cplen] = vec[0:cplen]
        delta_f = vec.delta_f
    elif isinstance(vec, TimeSeries):
        if delta_f is None:
            raise ValueError("delta_f must be given when vec is a TimeSeries")
        # First determine if the timeseries is too short for the specified df
        # and increase if necessary
        new_length = int(nearest_larger_binary_number(len(vec)))
        while new_length * vec.delta_t < 1. / delta_f:
            new_length = new_length * 2
        vec.resize(new_length)
        # Then convert to frequencyseries
        v_tilde = vec.to_frequencyseries()
        # Then convert frequencyseries to required length and spacing by keeping
        # only every nth sample if delta_f needs increasing, and cutting at
        # Nyquist if the max frequency is too high.
        # NOTE: This assumes that the input and output data is using binary
        # lengths.
        i_delta_f = v_tilde.get_delta_f()
        v_tilde = v_tilde.numpy()
        df_ratio = int(delta_f / i_delta_f)
        n_freq_len = int((n - 1) * df_ratio + 1)
        assert n <= len(v_tilde)
        v_tilde = v_tilde[:n_freq_len:df_ratio]
        vectilde = FrequencySeries(v_tilde, delta_f=delta_f, dtype=complex64)
    else:
        raise TypeError("vec must be a TimeSeries or a FrequencySeries")

    return FrequencySeries(vectilde * DYN_RANGE_FAC, delta_f=delta_f,
                           dtype=complex64)


def get_waveform(approximant, phase_order, amplitude_order, spin_order,
                 wf_params, start_frequency, sample_rate, length,
                 filter_rate, sky_max_template=False):
    """Generate a waveform and return it as padded frequency series.

    Parameters
    ----------
    approximant : str
        Must be one of fd_approximants() or td_approximants().
    phase_order, amplitude_order, spin_order : int
        Orders forwarded to the waveform generator.
    wf_params : object
        Parameter row forwarded to the waveform generator. For time-domain
        approximants a `taper` attribute, if present, is applied to both
        polarizations.
    start_frequency : float
        Passed to the generator as f_lower.
    sample_rate : float
        Sample rate used for time-domain generation (delta_t = 1/sample_rate).
    length : int
        Number of time samples the padded output corresponds to.
    filter_rate : float
        Filter sample rate; delta_f is filter_rate / length.
    sky_max_template : bool, optional
        If True return the (plus, cross) pair separately instead of the
        single detector-projected series.

    Raises
    ------
    ValueError
        If `approximant` is neither a frequency- nor a time-domain approximant.
    """
    delta_f = float(filter_rate) / length
    if approximant in fd_approximants():
        hp, hc = get_fd_waveform(wf_params, approximant=approximant,
                                 phase_order=phase_order, delta_f=delta_f,
                                 spin_order=spin_order,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
    elif approximant in td_approximants():
        hp, hc = get_td_waveform(wf_params, approximant=approximant,
                                 phase_order=phase_order,
                                 spin_order=spin_order,
                                 delta_t=1. / sample_rate,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
        if hasattr(wf_params, 'taper'):
            hp = taper_timeseries(hp, wf_params.taper)
            hc = taper_timeseries(hc, wf_params.taper)
    else:
        raise ValueError("Cannot generate a waveform for approximant %s"
                         % approximant)

    # Use the `length` argument rather than the module-level filter_N so the
    # function does not depend on script globals (callers pass filter_N here).
    if not sky_max_template:
        hvec = generate_detector_strain(wf_params, hp, hc)
        return make_padded_frequency_series(hvec, length, delta_f=delta_f)
    return (make_padded_frequency_series(hp, length, delta_f=delta_f),
            make_padded_frequency_series(hc, length, delta_f=delta_f))


aprs = sorted(set(td_approximants() + fd_approximants()))
aprs += ['FROM_TEMPLATE_BANK']

#File output Settings
parser = ArgumentParser(description=__doc__)
parser.add_argument("--match-file", dest="out_file", metavar="FILE",
                    required=True, help="File to output match results")
parser.add_argument("--verbose", action='store_true', default=False,
                    help="Print verbose statements")
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)

#Template Settings
parser.add_argument("--template-file", dest="bank_file", metavar="FILE",
                    required=True,
                    help="SimInspiral or SnglInspiral XML file "
                         "containing the template parameters")
parser.add_argument("--total-mass-divide", type=float,
                    help="Total mass to switch from --template-approximant to "
                         "--highmass-approximant.")
parser.add_argument("--highmass-approximant", choices=aprs,
                    help="Waveform approximant for highmass templates.")
parser.add_argument("--template-approximant", choices=aprs, required=True,
                    help="Waveform approximant for templates")
parser.add_argument("--template-phase-order", default=-1, type=int,
                    help="PN order to use for the template phase")
parser.add_argument("--template-amplitude-order", default=-1, type=int,
                    help="PN order to use for the template amplitude")
parser.add_argument("--template-spin-order", default=-1, type=int,
                    help="PN order to use for the template spin terms")
parser.add_argument("--template-start-frequency", type=float,
                    help="Starting frequency for templates [Hz]")
parser.add_argument("--template-sample-rate", type=float,
                    help="Sample rate for templates [Hz]")

#Signal Settings
parser.add_argument("--signal-file", dest="sim_file", metavar="FILE",
                    required=True,
                    help="SimInspiral or SnglInspiral XML file "
                         "containing the signal parameters")
parser.add_argument("--signal-approximant", choices=aprs, required=True,
                    help="Waveform approximant for signals")
parser.add_argument("--signal-phase-order", default=-1, type=int,
                    help="PN order to use for the signal phase")
parser.add_argument("--signal-spin-order", default=-1, type=int,
                    help="PN order to use for the signal spin terms")
parser.add_argument("--signal-amplitude-order", default=-1, type=int,
                    help="PN order to use for the signal amplitude")
parser.add_argument("--signal-start-frequency", type=float,
                    help="Starting frequency for signals [Hz]")
parser.add_argument("--signal-sample-rate", type=float,
                    help="Sample rate for signals [Hz]")
parser.add_argument("--use-sky-location", action='store_true',
                    help="Inject into a theoretical detector at the celestial "
                         "North pole of a non-rotating Earth rather than overhead")

#Filtering Settings
parser.add_argument('--filter-low-frequency-cutoff', metavar='FREQ',
                    type=float, required=True,
                    help='low frequency cutoff of matched filter')
# NOTE(review): this option is parsed but never read below; the per-template
# cutoff is always taken from the template's `f_lower` attribute.
parser.add_argument('--filter-low-freq-cutoff-column', metavar='NAME',
                    type=str,
                    help='If given, use a per-template low-frequency cutoff '
                         'from column NAME of the template table instead of '
                         'the value given via --filter-low-frequency-cutoff. '
                         'Signals are still normalized using '
                         '--filter-low-frequency-cutoff and the value in '
                         'column NAME must be larger than or equal to it.')
parser.add_argument("--filter-sample-rate", type=float, required=True,
                    help="Filter sample rate [Hz]")
parser.add_argument("--filter-signal-length", type=int, required=True,
                    help="Length of signal for filtering, shoud be longer "
                         "than all waveforms and include some padding")
parser.add_argument("--sky-maximization-method", required=True,
                    choices=["precessing", "hom"])

# add PSD options
pycbc.psd.insert_psd_option_group(parser, output=False)

# Insert the data reading options
pycbc.strain.insert_strain_option_group(parser)

#hardware support
pycbc.scheme.insert_processing_option_group(parser)
pycbc.fft.insert_fft_option_group(parser)

#Restricted maximization
parser.add_argument("--mchirp-window", type=str, metavar="FRACTION",
                    help="Ignore templates whose chirp mass deviates from "
                         "signal's one more than given fraction. Provide two "
                         "comma separated numbers to have different bounds "
                         "above and below the signal's, with below bound "
                         "listed first.")
parser.add_argument("--tau0-window", type=float, metavar="TIME", default=None,
                    help="Ignore templates whose Newtonian order chirp time "
                         "(tau0) varies from the signals by more than the "
                         "supplied amount. If option is not provided no "
                         "window on tau0 is used. The "
                         "filter-low-frequency-cutoff is used to calculate "
                         "the value of tau0 for all cases.")

options = parser.parse_args()

pycbc.init_logging(options.verbose)

pycbc.psd.verify_psd_options(options, parser)

if options.psd_estimation:
    pycbc.strain.verify_strain_options(options, parser)

if options.total_mass_divide and options.highmass_approximant is None:
    parser.error("You must provide a highmass-approximant if you want total-mass-divide.")

# Build the chirp-mass veto once, according to how --mchirp-window was given.
if options.mchirp_window is None:
    def outside_mchirp_window(template_mchirp, signal_mchirp):
        """No window requested: never veto a template."""
        return False
elif ',' in options.mchirp_window:
    # asymmetric chirp mass window
    mchirp_window_bounds = options.mchirp_window.split(",")
    mchirp_window_lower = float(mchirp_window_bounds[0])
    mchirp_window_upper = float(mchirp_window_bounds[1])

    def outside_mchirp_window(template_mchirp, signal_mchirp):
        """Veto if the fractional offset exceeds the lower/upper bound."""
        delta = (template_mchirp - signal_mchirp) / signal_mchirp
        return delta > mchirp_window_upper or -delta > mchirp_window_lower
else:
    # symmetric chirp mass window
    mchirp_window = float(options.mchirp_window)

    def outside_mchirp_window(template_mchirp, signal_mchirp):
        """Veto if the absolute offset exceeds window * signal chirp mass."""
        return abs(signal_mchirp - template_mchirp) > \
            (mchirp_window * signal_mchirp)

if options.tau0_window is None:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        """No window requested: never veto a template."""
        return False
else:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        """Veto if the tau0 values differ by more than `window`."""
        return abs(signal_tau0 - template_tau0) > window

# If we are going to use h(t) to estimate a PSD we need h(t)
if options.psd_estimation:
    logging.info("Obtaining h(t) for PSD generation")
    strain = pycbc.strain.from_cli(options, pycbc.DYN_RANGE_FAC)
else:
    strain = None

# Template and signal sample rates fall back on the filter sample rate.
if options.template_sample_rate is not None:
    template_sample_rate = options.template_sample_rate
else:
    template_sample_rate = options.filter_sample_rate

if options.signal_sample_rate is not None:
    signal_sample_rate = options.signal_sample_rate
else:
    signal_sample_rate = options.filter_sample_rate

ctx = pycbc.scheme.from_cli(options)

logging.info('Reading template bank')
temp_bank = TemplateBank(options.bank_file)
template_table = temp_bank.table
logging.info(" %d templates", len(template_table))

logging.info('Reading simulation list')
indoc = ligolw_utils.load_filename(options.sim_file, False,
                                   contenthandler=mycontenthandler)
try:
    signal_table = table.get_table(indoc, lsctables.SimInspiralTable.tableName)
except ValueError:
    # No SimInspiral table in the file: read the SnglInspiral table instead.
    signal_table = table.get_table(indoc, lsctables.SnglInspiralTable.tableName)
logging.info(" %d signal waveforms", len(signal_table))

logging.info("Matches will be written to %s", options.out_file)

filter_N = int(options.filter_signal_length * options.filter_sample_rate)
# Floor division keeps the PSD length an integer on Python 2 and 3.
filter_n = filter_N // 2 + 1
filter_delta_f = 1.0 / float(options.filter_signal_length)

logging.info("Reading and Interpolating PSD")
psd = pycbc.psd.from_cli(options, filter_n, filter_delta_f,
                         options.filter_low_frequency_cutoff, strain=strain,
                         dyn_range_factor=pycbc.DYN_RANGE_FAC,
                         precision='single')

with ctx:
    pycbc.fft.from_cli(options)

    logging.info("Pregenerating Signals")
    # Each entry is (normalized+whitened signal, sigmasq, per-template match
    # list, signal parameters); the match list is filled in template order.
    signals = []
    sig_m1 = []
    sig_m2 = []
    for index, signal_params in enumerate(signal_table):
        if options.verbose:
            update_progress(float(index + 1) / len(signal_table))
        if not options.use_sky_location:
            signal_params.latitude = 0.
            signal_params.longitude = 0.
        stilde = get_waveform(options.signal_approximant,
                              options.signal_phase_order,
                              options.signal_amplitude_order,
                              options.signal_spin_order,
                              signal_params,
                              options.signal_start_frequency,
                              signal_sample_rate,
                              filter_N, options.filter_sample_rate)
        s_norm = sigmasq(stilde, psd=psd,
                         low_frequency_cutoff=options.filter_low_frequency_cutoff)
        stilde /= sqrt(float(s_norm))
        # Divide by the PSD once here; the matched filters below are then
        # called without a psd argument.
        stilde /= psd
        signals.append((stilde, s_norm, [], signal_params))
        sig_m1.append(signal_params.mass1)
        sig_m2.append(signal_params.mass2)
    sig_m1 = array(sig_m1)
    sig_m2 = array(sig_m2)
    sig_tau0, _ = mass1_mass2_to_tau0_tau3(sig_m1, sig_m2,
                                           options.filter_low_frequency_cutoff)
    sig_mchirp, _ = mass1_mass2_to_mchirp_eta(sig_m1, sig_m2)

    logging.info("Calculating Mchirp and Tau0")
    template_m1 = []
    template_m2 = []
    for template_params in template_table:
        template_m1.append(template_params.mass1)
        template_m2.append(template_params.mass2)
    template_m1 = array(template_m1)
    template_m2 = array(template_m2)
    template_tau0, _ = mass1_mass2_to_tau0_tau3(template_m1, template_m2,
                                                options.filter_low_frequency_cutoff)
    template_mchirp, _ = mass1_mass2_to_mchirp_eta(template_m1, template_m2)

    logging.info("Calculating Overlaps")
    flow_warned = False
    for index, template_params in enumerate(template_table):
        if options.verbose:
            update_progress(float(index + 1) / len(template_table))
        f_lower = template_params.f_lower
        # If not set fall back on filter low-freq cutoff
        if f_lower < 0.000001:
            f_lower = options.filter_low_frequency_cutoff
        if f_lower < options.filter_low_frequency_cutoff:
            # Not entirely clear what to do here?
            if not flow_warned:
                logging.warning("Template's flower is smaller than "
                                "--filter-low-frequency-cutoff. Raising flower "
                                "of template to match.")
                flow_warned = True
            f_lower = options.filter_low_frequency_cutoff

        # htilde is only a flag: None until this template's waveform has been
        # generated, so the template is built lazily and at most once.
        htilde = None
        for sidx, (stilde, s_norm, matches, signal_params) in enumerate(signals):
            # Check if we need to look at this
            check_logic = stilde is None
            check_logic |= outside_tau0_window(template_tau0[index],
                                               sig_tau0[sidx],
                                               options.tau0_window)
            check_logic |= outside_mchirp_window(template_mchirp[index],
                                                 sig_mchirp[sidx])
            if check_logic:
                matches.append(0)
                continue

            # Generate htilde if we haven't already done so
            if htilde is None:
                # FIXME: I would like to remove the approximant options and
                # have this entirely controlled by the template bank.
                # However, while we are still using the high-mass divide
                # in XML banks, this must be retained.
                try:
                    this_approximant = template_params['approximant']
                except Exception:
                    # The bank row carries no usable approximant: fall back on
                    # the command-line choice. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt/SystemExit propagate.
                    # NOTE(review): the high-mass override is placed in this
                    # fallback branch; confirm against the original layout.
                    this_approximant = options.template_approximant
                    if options.total_mass_divide is not None and \
                            (template_params.mass1 + template_params.mass2) >= \
                            options.total_mass_divide:
                        this_approximant = options.highmass_approximant

                hplus, hcross = get_waveform(this_approximant,
                                             options.template_phase_order,
                                             options.template_amplitude_order,
                                             options.template_spin_order,
                                             template_params,
                                             options.template_start_frequency,
                                             template_sample_rate,
                                             filter_N,
                                             options.filter_sample_rate,
                                             sky_max_template=True)
                hp_norm = sigmasq(hplus, psd=psd,
                                  low_frequency_cutoff=f_lower)
                hc_norm = sigmasq(hcross, psd=psd,
                                  low_frequency_cutoff=f_lower)
                hplus /= sqrt(float(hp_norm))
                hcross /= sqrt(float(hc_norm))
                hpc_corr = overlap_cplx(hplus, hcross, psd=psd,
                                        low_frequency_cutoff=options.filter_low_frequency_cutoff,
                                        normalized=False)
                hpc_corr_R = real(hpc_corr)
                htilde = 1

            I_plus = matched_filter(hplus, stilde,
                                    low_frequency_cutoff=options.filter_low_frequency_cutoff,
                                    sigmasq=1.)
            I_cross = matched_filter(hcross, stilde,
                                     low_frequency_cutoff=options.filter_low_frequency_cutoff,
                                     sigmasq=1.)
            if options.sky_maximization_method == 'precessing':
                det_stat = compute_max_snr_over_sky_loc_stat(
                    I_plus, I_cross, hpc_corr_R, hpnorm=1., hcnorm=1.,
                    thresh=0.1, analyse_slice=slice(0, len(I_plus.data)))
            elif options.sky_maximization_method == 'hom':
                det_stat = compute_max_snr_over_sky_loc_stat_no_phase(
                    I_plus, I_cross, hpc_corr_R, hpnorm=1., hcnorm=1.,
                    thresh=0.1, analyse_slice=slice(0, len(I_plus.data)))
            else:
                err_msg = "I really shouldn't be here! Who gone broked me?"
                raise ValueError(err_msg)
            i = argmax(det_stat.data)
            o = det_stat[i]
            matches.append(o)

logging.info("Determining maximum overlaps and outputting results")

# Find the maximum overlap in the bank and output to a file
with open(options.out_file, "w") as fout:
    for i, (stilde, s_norm, matches, sim_template) in enumerate(signals):
        best_match = max(matches)
        match_str = "%5.5f " % best_match
        match_str += " " + options.bank_file
        match_str += " " + str(matches.index(best_match))
        match_str += " " + options.sim_file
        match_str += " %d" % i
        match_str += " %5.5f\n" % s_norm
        fout.write(match_str)