Revision 6fd711f62b16977a90c97b13f81fe90f24f38f1a authored by Sebastian Khan on 25 May 2017, 09:45:24 UTC, committed by GitHub on 25 May 2017, 09:45:24 UTC
* add coaphase to hwinj allows you to specify the reference orbital phase when using pycbc_generate_hwinj coaphase is the phiRef parameter in LALSimulation but it's called coaphase in sim_inspiral tables. * fix naming bug * set default value to zero * update help message * fix build and implement ian's comment
1 parent b48e9b1
pycbc_banksim
#! /usr/bin/env python
# Copyright (C) 2012 Alex Nitz
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
"""Calculate the fitting factors of simulated signals with a template bank."""
import sys
import logging
from numpy import complex64, float32, array
from argparse import ArgumentParser
from pycbc_glue.ligolw import utils as ligolw_utils
from pycbc_glue.ligolw import table, lsctables
from pycbc_glue.ligolw.ligolw import LIGOLWContentHandler
class mycontenthandler(LIGOLWContentHandler):
    """Content handler passed to ``ligolw_utils.load_filename`` when the
    signal XML file is read.

    It adds nothing to ``LIGOLWContentHandler``; it exists so that a
    script-local class, rather than the shared base class, is handed to
    ``lsctables.use_in``.
    """


# NOTE(review): presumably this makes the lsctables table classes available
# while parsing with this handler -- confirm against the glue documentation.
lsctables.use_in(mycontenthandler)
from pycbc.pnutils import mass1_mass2_to_mchirp_eta, f_SchwarzISCO
from pycbc.pnutils import mass1_mass2_to_tau0_tau3
from pycbc.pnutils import nearest_larger_binary_number
from pycbc.waveform import get_td_waveform, get_fd_waveform, td_approximants, fd_approximants
from pycbc.waveform.utils import taper_timeseries
from pycbc import DYN_RANGE_FAC
from pycbc.types import FrequencySeries, TimeSeries, zeros, real_same_precision_as, complex_same_precision_as
from pycbc.filter import match, sigmasq, resample_to_delta_t
from math import ceil, log
import pycbc.psd, pycbc.scheme, pycbc.fft, pycbc.strain, pycbc.version
from pycbc.detector import overhead_antenna_pattern as generate_fplus_fcross
from pycbc.waveform import TemplateBank
def update_progress(progress):
    """Draw a 50-character text progress bar on standard output.

    Parameters
    ----------
    progress : float
        Completed fraction of the work, expected to lie in [0, 1].

    Notes
    -----
    The bar is followed by carriage returns and no newline, so that
    successive calls overwrite each other on a terminal. When ``progress``
    equals 1 a final ``Done`` line, padded with spaces to wipe the bar, is
    written. Standard output is flushed on every call.
    """
    # Floor division keeps the repeat count an integer whether or not true
    # division is in effect; each '#' stands for two percent.
    filled = int(progress * 100) // 2
    bar = '#' * filled + ' ' * (50 - filled)
    # sys.stdout.write emits exactly the same bytes as the former
    # ``print ...,`` statement while being valid on both Python 2 and 3.
    sys.stdout.write('[{0}] {1:.2%}\r\r'.format(bar, progress))
    if progress == 1:
        sys.stdout.write('Done' + ' ' * 70 + '\n')
    sys.stdout.flush()
## Remove the need for these functions ########################################
def generate_detector_strain(template_params, h_plus, h_cross):
    """Combine the plus and cross polarizations into one strain series.

    The ``latitude``, ``longitude`` and ``polarization`` attributes of
    ``template_params`` are used when present; any that is missing is taken
    to be 0. The antenna factors are obtained from
    ``generate_fplus_fcross(longitude, latitude, polarization)`` and the
    function returns ``h_plus * f_plus + h_cross * f_cross``.
    """
    sky_latitude = getattr(template_params, 'latitude', 0)
    sky_longitude = getattr(template_params, 'longitude', 0)
    pol_angle = getattr(template_params, 'polarization', 0)

    f_plus, f_cross = generate_fplus_fcross(sky_longitude, sky_latitude,
                                            pol_angle)
    return h_plus * f_plus + h_cross * f_cross
def make_padded_frequency_series(vec, filter_N=None, delta_f=None):
    """Convert vec (TimeSeries or FrequencySeries) to a FrequencySeries. If
    filter_N and/or delta_f are given, the output will take those values. If
    not told otherwise the code will attempt to pad a timeseries first such that
    the waveform will not wraparound. However, if delta_f is specified to be
    shorter than the waveform length then wraparound *will* be allowed.

    Parameters
    ----------
    vec : TimeSeries or FrequencySeries
        Data to convert. A TimeSeries is resized through ``vec.resize``
        before being transformed.
    filter_N : int, optional
        Number of time-domain samples the output corresponds to; the output
        has ``filter_N // 2 + 1`` frequency samples. If omitted it is derived
        from ``len(vec)``.
    delta_f : float, optional
        Frequency spacing of the output. Required for TimeSeries input; for
        FrequencySeries input it is ignored and ``vec.delta_f`` is used.

    Returns
    -------
    FrequencySeries
        complex64 series holding the data multiplied by DYN_RANGE_FAC.

    Raises
    ------
    ValueError
        If ``vec`` is a TimeSeries and ``delta_f`` is not given.
    TypeError
        If ``vec`` is neither a TimeSeries nor a FrequencySeries.
    """
    if filter_N is None:
        # One power of two beyond the smallest that can hold vec. The int()
        # keeps N integral even where ceil() returns a float.
        N = 2 ** (int(ceil(log(len(vec), 2))) + 1)
    else:
        N = int(filter_N)
    # Floor division: n is used as an array length and a slice bound.
    n = N // 2 + 1

    if isinstance(vec, FrequencySeries):
        vectilde = FrequencySeries(zeros(n, dtype=complex_same_precision_as(vec)),
                                   delta_f=1.0, copy=False)
        # Copy as many samples as fit; the remainder stays zero.
        cplen = min(len(vectilde), len(vec))
        vectilde[0:cplen] = vec[0:cplen]
        delta_f = vec.delta_f
    elif isinstance(vec, TimeSeries):
        if delta_f is None:
            raise ValueError("delta_f must be given to convert a TimeSeries")
        # First determine if the timeseries is too short for the specified df
        # and increase if necessary
        new_length = int(nearest_larger_binary_number(len(vec)))
        while new_length * vec.delta_t < 1. / delta_f:
            new_length = new_length * 2
        vec.resize(new_length)
        # Then convert to frequencyseries
        v_tilde = vec.to_frequencyseries()
        # Then convert frequencyseries to required length and spacing by keeping
        # only every nth sample if delta_f needs increasing, and cutting at
        # Nyquist if the max frequency is too high.
        # NOTE: This assumes that the input and output data is using binary
        # lengths.
        i_delta_f = v_tilde.get_delta_f()
        v_tilde = v_tilde.numpy()
        df_ratio = int(delta_f / i_delta_f)
        n_freq_len = int((n - 1) * df_ratio + 1)
        assert n <= len(v_tilde)
        v_tilde = v_tilde[:n_freq_len:df_ratio]
        vectilde = FrequencySeries(v_tilde, delta_f=delta_f, dtype=complex64)
    else:
        raise TypeError("vec must be a TimeSeries or a FrequencySeries, "
                        "got %s" % type(vec).__name__)

    return FrequencySeries(vectilde * DYN_RANGE_FAC, delta_f=delta_f,
                           dtype=complex64)
def get_waveform(approximant, phase_order, amplitude_order, spin_order,
                 wf_params, start_frequency, sample_rate, length,
                 filter_rate):
    """Generate a waveform and return it as a padded frequency series.

    Parameters
    ----------
    approximant : str
        Name of the waveform model; it must be listed by either
        ``fd_approximants()`` or ``td_approximants()``.
    phase_order, amplitude_order, spin_order : int
        Orders forwarded unchanged to the waveform generator.
    wf_params : object
        Row or record with the waveform parameters. Its optional ``taper``
        attribute is applied to time-domain waveforms; its optional sky
        location attributes are used by ``generate_detector_strain``.
    start_frequency : float
        Passed to the generator as ``f_lower``.
    sample_rate : float
        Sample rate used for time-domain generation (``delta_t`` is its
        inverse).
    length : int
        Number of time-domain samples of the filter; sets the size of the
        returned series.
    filter_rate : float
        Sample rate of the filter; ``filter_rate / length`` is the frequency
        spacing of the returned series.

    Returns
    -------
    FrequencySeries
        The result of ``make_padded_frequency_series`` for the detector
        strain.

    Raises
    ------
    ValueError
        If ``approximant`` is neither a frequency-domain nor a time-domain
        approximant.
    """
    # float() guards against truncation if both arguments are integers.
    delta_f = float(filter_rate) / length

    if approximant in fd_approximants():
        hp, hc = get_fd_waveform(wf_params, approximant=approximant,
                                 phase_order=phase_order, delta_f=delta_f,
                                 spin_order=spin_order,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
    elif approximant in td_approximants():
        hp, hc = get_td_waveform(wf_params,
                                 approximant=approximant,
                                 phase_order=phase_order,
                                 spin_order=spin_order,
                                 delta_t=1./sample_rate,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
        if hasattr(wf_params, 'taper'):
            hp = taper_timeseries(hp, wf_params.taper)
            hc = taper_timeseries(hc, wf_params.taper)
    else:
        raise ValueError("Approximant %s is neither a frequency-domain nor "
                         "a time-domain approximant" % approximant)

    hvec = generate_detector_strain(wf_params, hp, hc)
    # Use the caller-supplied length rather than reaching for the
    # module-level filter_N, so the function only depends on its arguments.
    return make_padded_frequency_series(hvec, length, delta_f=delta_f)
# Every approximant name known to either waveform domain, without duplicates;
# used as the allowed values of the approximant options below.
aprs = sorted(set(td_approximants() + fd_approximants()))

# File output settings
parser = ArgumentParser(description=__doc__)
parser.add_argument("--match-file", dest="out_file", metavar="FILE",
                    required=True, help="File to output match results")
parser.add_argument("--verbose", action='store_true', default=False,
                    help="Print verbose statements")
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)

# Template settings
parser.add_argument("--template-file", dest="bank_file", metavar="FILE",
                    required=True, help="SimInspiral or SnglInspiral XML file "
                    "containing the template parameters")
parser.add_argument("--total-mass-divide", type=float,
                    help="Total mass to switch from --template-approximant to "
                    "--highmass-approximant.")
parser.add_argument("--highmass-approximant", choices=aprs,
                    help="Waveform approximant for highmass templates.")
parser.add_argument("--template-approximant", choices=aprs, required=True,
                    help="Waveform approximant for templates")
parser.add_argument("--template-phase-order", default=-1, type=int,
                    help="PN order to use for the template phase")
parser.add_argument("--template-amplitude-order", default=-1, type=int,
                    help="PN order to use for the template amplitude")
parser.add_argument("--template-spin-order", default=-1, type=int,
                    help="PN order to use for the template spin terms")
parser.add_argument("--template-start-frequency", type=float,
                    help="Starting frequency for templates [Hz]")
parser.add_argument("--template-sample-rate", type=float,
                    help="Sample rate for templates [Hz]")

# Signal settings
parser.add_argument("--signal-file", dest="sim_file", metavar="FILE",
                    required=True, help="SimInspiral or SnglInspiral XML file "
                    "containing the signal parameters")
parser.add_argument("--signal-approximant", choices=aprs, required=True,
                    help="Waveform approximant for signals")
parser.add_argument("--signal-phase-order", default=-1, type=int,
                    help="PN order to use for the signal phase")
parser.add_argument("--signal-spin-order", default=-1, type=int,
                    help="PN order to use for the signal spin terms")
parser.add_argument("--signal-amplitude-order", default=-1, type=int,
                    help="PN order to use for the signal amplitude")
parser.add_argument("--signal-start-frequency", type=float,
                    help="Starting frequency for signals [Hz]")
parser.add_argument("--signal-sample-rate", type=float,
                    help="Sample rate for signals [Hz]")
parser.add_argument("--use-sky-location", action='store_true',
                    help="Inject into a theoretical detector at the celestial "
                    "North pole of a non-rotating Earth rather than overhead")

# Filtering settings
parser.add_argument('--filter-low-frequency-cutoff', metavar='FREQ', type=float,
                    required=True, help='low frequency cutoff of matched filter')
parser.add_argument("--filter-sample-rate", type=float, required=True,
                    help="Filter sample rate [Hz]")
parser.add_argument("--filter-signal-length", type=int, required=True,
                    help="Length of signal for filtering, should be longer "
                    "than all waveforms and include some padding")

# Add PSD options
pycbc.psd.insert_psd_option_group(parser, output=False)

# Insert the data reading options
pycbc.strain.insert_strain_option_group(parser)

# Hardware support
pycbc.scheme.insert_processing_option_group(parser)
pycbc.fft.insert_fft_option_group(parser)

# Restricted maximization
parser.add_argument("--mchirp-window", type=str, metavar="FRACTION",
                    help="Ignore templates whose chirp mass deviates from "
                    "signal's one more than given fraction. Provide two "
                    "comma separated numbers to have different bounds "
                    "above and below the signal's, with below bound "
                    "listed first.")
parser.add_argument("--tau0-window", type=float, metavar="TIME", default=None,
                    help="Ignore templates whose Newtonian order chirp time "
                    "(tau0) varies from the signals by more than the "
                    "supplied amount. If option is not provided no "
                    "window on tau0 is used. The "
                    "filter-low-frequency-cutoff is used to calculate "
                    "the value of tau0 for all cases. Provided in units "
                    "of seconds.")

options = parser.parse_args()
pycbc.init_logging(options.verbose)

# Validate option combinations before any expensive work is started.
pycbc.psd.verify_psd_options(options, parser)
if options.psd_estimation:
    pycbc.strain.verify_strain_options(options, parser)

# Test against None, matching the check made when templates are generated, so
# that an explicit --total-mass-divide of 0 is validated as well.
if (options.total_mass_divide is not None
        and options.highmass_approximant is None):
    parser.error("You must provide a highmass-approximant if you want total-mass-divide.")

# Build the chirp-mass predicate once. It returns True when a template lies
# outside the window around the signal and must be skipped.
if options.mchirp_window is None:
    def outside_mchirp_window(template_mchirp, signal_mchirp):
        return False
elif ',' in options.mchirp_window:
    # asymmetric chirp mass window, given as "below,above"
    mchirp_window_bounds = options.mchirp_window.split(",")
    mchirp_window_lower = float(mchirp_window_bounds[0])
    mchirp_window_upper = float(mchirp_window_bounds[1])

    def outside_mchirp_window(template_mchirp, signal_mchirp):
        delta = (template_mchirp - signal_mchirp) / signal_mchirp
        return delta > mchirp_window_upper or -delta > mchirp_window_lower
else:
    # symmetric chirp mass window
    mchirp_window = float(options.mchirp_window)

    def outside_mchirp_window(template_mchirp, signal_mchirp):
        return abs(signal_mchirp - template_mchirp) > \
            (mchirp_window * signal_mchirp)

# Same idea for the tau0 window; here the window width is passed in by the
# caller on each call.
if options.tau0_window is None:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        return False
else:
    def outside_tau0_window(template_tau0, signal_tau0, window):
        return abs(signal_tau0 - template_tau0) > window

# If we are going to use h(t) to estimate a PSD we need h(t)
if options.psd_estimation:
    logging.info("Obtaining h(t) for PSD generation")
    strain = pycbc.strain.from_cli(options, pycbc.DYN_RANGE_FAC)
else:
    strain = None

# Template and signal sample rates default to the filter sample rate.
if options.template_sample_rate is not None:
    template_sample_rate = options.template_sample_rate
else:
    template_sample_rate = options.filter_sample_rate

if options.signal_sample_rate is not None:
    signal_sample_rate = options.signal_sample_rate
else:
    signal_sample_rate = options.filter_sample_rate
# Processing scheme selected on the command line; it is entered as a context
# manager around the waveform generation and matching below.
ctx = pycbc.scheme.from_cli(options)

logging.info('Reading template bank')
temp_bank = TemplateBank(options.bank_file)
template_table = temp_bank.table
logging.info(" %d templates", len(template_table))

logging.info('Reading simulation list')
indoc = ligolw_utils.load_filename(options.sim_file, False,
                                   contenthandler=mycontenthandler)
# The signal file may hold either kind of table: try SimInspiral first and
# fall back to SnglInspiral when get_table raises ValueError.
try:
    signal_table = table.get_table(indoc, lsctables.SimInspiralTable.tableName)
except ValueError:
    signal_table = table.get_table(indoc, lsctables.SnglInspiralTable.tableName)
logging.info(" %d signal waveforms", len(signal_table))

logging.info("Matches will be written to %s", options.out_file)

# Filter dimensions: filter_N time samples give filter_n frequency samples
# spaced by filter_delta_f. Floor division keeps filter_n an integer even
# when true division is in effect.
filter_N = int(options.filter_signal_length * options.filter_sample_rate)
filter_n = filter_N // 2 + 1
filter_delta_f = 1.0 / float(options.filter_signal_length)

logging.info("Reading and Interpolating PSD")
psd = pycbc.psd.from_cli(options, filter_n, filter_delta_f,
                         options.filter_low_frequency_cutoff, strain=strain,
                         dyn_range_factor=pycbc.DYN_RANGE_FAC,
                         precision='single')
with ctx:
    pycbc.fft.from_cli(options)

    logging.info("Pregenerating Signals")
    # One entry per signal: (stilde, s_norm, matches, signal_params), where
    # matches collects one value per template in bank order.
    signals = []
    # Used for getting mchirp/tau0 later
    sig_m1 = []
    sig_m2 = []
    num_signals = len(signal_table)
    for index, signal_params in enumerate(signal_table):
        if options.verbose:
            update_progress(float(index + 1) / num_signals)
        if not options.use_sky_location:
            signal_params.latitude = 0.
            signal_params.longitude = 0.
        stilde = get_waveform(options.signal_approximant,
                              options.signal_phase_order,
                              options.signal_amplitude_order,
                              options.signal_spin_order,
                              signal_params,
                              options.signal_start_frequency,
                              signal_sample_rate,
                              filter_N, options.filter_sample_rate)
        # The normalization is taken before the signal is divided by the PSD;
        # match() below is then called without a psd argument.
        s_norm = sigmasq(stilde, psd=psd,
                         low_frequency_cutoff=options.filter_low_frequency_cutoff)
        stilde /= psd
        signals.append((stilde, s_norm, [], signal_params))
        sig_m1.append(signal_params.mass1)
        sig_m2.append(signal_params.mass2)
    sig_m1 = array(sig_m1)
    sig_m2 = array(sig_m2)
    sig_tau0, _ = mass1_mass2_to_tau0_tau3(sig_m1, sig_m2,
                                           options.filter_low_frequency_cutoff)
    sig_mchirp, _ = mass1_mass2_to_mchirp_eta(sig_m1, sig_m2)

    logging.info("Calculating Mchirp and Tau0")
    template_m1 = array([tp.mass1 for tp in template_table])
    template_m2 = array([tp.mass2 for tp in template_table])
    template_tau0, _ = mass1_mass2_to_tau0_tau3(template_m1, template_m2,
                                                options.filter_low_frequency_cutoff)
    template_mchirp, _ = mass1_mass2_to_mchirp_eta(template_m1, template_m2)

    logging.info("Calculating Overlaps")
    flow_warned = False
    num_templates = len(template_table)
    for index, template_params in enumerate(template_table):
        if options.verbose:
            update_progress(float(index + 1) / num_templates)

        f_lower = template_params.f_lower
        # If not set fall back on filter low-freq cutoff
        if f_lower < 0.000001:
            f_lower = options.filter_low_frequency_cutoff
        if f_lower < options.filter_low_frequency_cutoff:
            # Not entirely clear what to do here?
            if not flow_warned:
                # Warn only once, however many templates are affected.
                logging.warning("Template's flower is smaller than "
                                "--filter-low-frequency-cutoff. Raising flower "
                                "of template to match.")
                flow_warned = True
            f_lower = options.filter_low_frequency_cutoff

        # The template waveform is generated lazily, only once some signal
        # actually falls inside the windows.
        h_norm = htilde = None
        for sidx, (stilde, s_norm, matches, signal_params) in enumerate(signals):
            # Check if we need to look at this
            check_logic = stilde is None
            check_logic |= outside_tau0_window(template_tau0[index],
                                               sig_tau0[sidx],
                                               options.tau0_window)
            check_logic |= outside_mchirp_window(template_mchirp[index],
                                                 sig_mchirp[sidx])
            if check_logic:
                # Record a zero so that list positions keep lining up with
                # template indices.
                matches.append(0)
                continue

            # Generate htilde if we haven't already done so
            if htilde is None:
                # FIXME: I would like to remove the approximant options and
                # have this entirely controlled by the template bank.
                # However, while we are still using the high-mass divide
                # in XML banks, this must be retained.
                try:
                    this_approximant = template_params['approximant']
                except Exception:
                    # The exception raised for a bank without an approximant
                    # column depends on the table type, so any ordinary error
                    # falls back to the command-line value; unlike a bare
                    # except, KeyboardInterrupt and SystemExit still propagate.
                    this_approximant = options.template_approximant
                if (options.total_mass_divide is not None
                        and (template_params.mass1 + template_params.mass2)
                        >= options.total_mass_divide):
                    this_approximant = options.highmass_approximant
                htilde = get_waveform(this_approximant,
                                      options.template_phase_order,
                                      options.template_amplitude_order,
                                      options.template_spin_order,
                                      template_params,
                                      options.template_start_frequency,
                                      template_sample_rate,
                                      filter_N, options.filter_sample_rate)
                h_norm = sigmasq(htilde, psd=psd, low_frequency_cutoff=f_lower)

            # Only the first value returned by match() is kept.
            overlap, _ = match(htilde, stilde, v1_norm=h_norm, v2_norm=s_norm,
                               low_frequency_cutoff=f_lower)
            matches.append(overlap)
logging.info("Determining maximum overlaps and outputting results")

# Find the maximum overlap in the bank and output to a file.
# Each signal produces one line holding, in order: the best match, the bank
# file name, the index of the best-matching template, the signal file name,
# the index of the signal and the signal's normalization.
with open(options.out_file, "w") as fout:
    for sig_index, (_, s_norm, matches, _) in enumerate(signals):
        best_match = max(matches)
        fields = (
            "%5.5f " % best_match,
            options.bank_file,
            str(matches.index(best_match)),
            options.sim_file,
            "%d" % sig_index,
            "%5.5f\n" % s_norm,
        )
        fout.write(" ".join(fields))
Computing file changes ...