# https://github.com/gwastro/pycbc
# Raw File
# Tip revision: e25f027d8f7bcba252e520a989a0db44c5c1c041 authored by Duncan Brown on 20 October 2015, 20:24:10 UTC
# Merge pull request #516 from lppekows/static_builds
# Tip revision: e25f027
# pycbc_banksim
#! /usr/bin/env python
# Copyright (C) 2012  Alex Nitz
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
from time import time

# Reference instant, taken before the heavier imports below so that the
# reported timings include the import cost.
start = time()


def elapsed_time():
    """Return the number of seconds that have passed since ``start``."""
    return time() - start

import sys
from numpy import complex64, float32
from argparse import ArgumentParser
from glue.ligolw import utils as ligolw_utils
from glue.ligolw import table, lsctables
from glue.ligolw.ligolw import LIGOLWContentHandler
class mycontenthandler(LIGOLWContentHandler):
    """Content handler passed to ``ligolw_utils.load_filename`` when the
    template bank and signal XML files are read."""


# NOTE(review): presumably this makes the handler aware of the lsctables
# table definitions -- confirm against the glue.ligolw documentation.
lsctables.use_in(mycontenthandler)

from pycbc.pnutils import mass1_mass2_to_mchirp_eta, f_SchwarzISCO
from pycbc.waveform import get_td_waveform, get_fd_waveform, td_approximants, fd_approximants
from pycbc import DYN_RANGE_FAC
from pycbc.types import FrequencySeries, TimeSeries, zeros, real_same_precision_as, complex_same_precision_as
from pycbc.filter import match, sigmasq, resample_to_delta_t
from pycbc.fft import fft
from math import cos, sin, ceil, log
import pycbc.psd, pycbc.scheme, pycbc.fft, pycbc.strain


def update_progress(progress):
    """Redraw a 50-character text progress bar on standard output.

    Parameters
    ----------
    progress : float
        Completed fraction of the work, from 0 to 1 (the callers in this
        script pass ``float(index) / total``).

    The bar is rewritten in place using carriage returns. Once the fraction
    reaches 1, " Done" and a newline are appended.
    """
    # One '#' per two percent; floor division keeps the count an integer.
    filled = int(progress * 100) // 2
    bar = '#' * filled + ' ' * (50 - filled)
    # Writing straight to sys.stdout leaves the cursor on the current line
    # and does not depend on print being a statement or a function.
    sys.stdout.write('\r\r[{0}] {1:.2%}'.format(bar, progress))
    # progress is a fraction, so completion is 1; the former comparison
    # against 100 could never succeed and "Done" was never shown.
    if progress >= 1:
        sys.stdout.write(' Done\n')
    sys.stdout.flush()

## Remove the need for these functions ########################################

def generate_fplus_fcross(latitude, longitude, polarization):
    """Return the pair ``(f_plus, f_cross)`` for the given angles.

    All three angles are passed to ``math.cos``/``math.sin`` and are therefore
    in radians. The two factors are the weights that
    ``generate_detector_strain`` applies to the plus and cross polarizations.
    """
    cos_lat = cos(latitude)
    cos_2lon = cos(2.0 * longitude)
    sin_2lon = sin(2.0 * longitude)
    cos_2pol = cos(2.0 * polarization)
    sin_2pol = sin(2.0 * polarization)

    # Common prefactor of the first term of both expressions.
    amp = (1.0 / 2.0) * (1.0 + cos_lat * cos_lat)

    f_plus = -amp * cos_2lon * cos_2pol - cos_lat * sin_2lon * sin_2pol
    f_cross = amp * cos_2lon * sin_2pol - cos_lat * sin_2lon * cos_2pol
    return f_plus, f_cross
    
def generate_detector_strain(template_params, h_plus, h_cross):
    """Combine the two polarizations into a single strain series.

    The ``latitude``, ``longitude`` and ``polarization`` attributes of
    ``template_params`` are used when present; any that is missing is taken
    to be 0. Returns ``f_plus * h_plus + f_cross * h_cross`` with the factors
    from ``generate_fplus_fcross``.
    """
    angles = []
    for attr in ('latitude', 'longitude', 'polarization'):
        if hasattr(template_params, attr):
            angles.append(getattr(template_params, attr))
        else:
            angles.append(0)

    f_plus, f_cross = generate_fplus_fcross(*angles)
    return f_plus * h_plus + f_cross * h_cross

def make_padded_frequency_series(vec, filter_N=None):
    """Return ``vec`` as a zero-padded complex64 FrequencySeries.

    Parameters
    ----------
    vec : TimeSeries or FrequencySeries
        Waveform to convert.
    filter_N : int, optional
        Number of time-domain samples the padded data corresponds to. When
        omitted, the smallest power of two not below ``len(vec)`` is doubled,
        so the padding is at least as long as the data; this prevents the
        effects of wraparound.

    Returns
    -------
    FrequencySeries
        Series of length ``filter_N // 2 + 1`` multiplied by
        ``DYN_RANGE_FAC``. For a FrequencySeries input the samples are copied
        (truncated if the input is longer) and ``vec.delta_f`` is kept. For a
        TimeSeries input the data are zero-padded to ``filter_N`` samples and
        Fourier transformed, giving ``delta_f = 1 / (delta_t * filter_N)``.

    Raises
    ------
    TypeError
        If ``vec`` is neither a TimeSeries nor a FrequencySeries.
    """
    if filter_N is None:
        # ceil() and log() return floats; convert so that N is a valid
        # array length.
        power = int(ceil(log(len(vec), 2))) + 1
        N = 2 ** power
    else:
        N = int(filter_N)
    # Floor division: n is an array length and must stay an integer.
    n = N // 2 + 1

    if isinstance(vec, FrequencySeries):
        vectilde = FrequencySeries(zeros(n, dtype=complex_same_precision_as(vec)),
                                   delta_f=1.0, copy=False)
        # Copy only as many samples as fit in the output.
        cplen = min(len(vectilde), len(vec))
        vectilde[0:cplen] = vec[0:cplen]
        delta_f = vec.delta_f

    elif isinstance(vec, TimeSeries):
        vec_pad = TimeSeries(zeros(N), delta_t=vec.delta_t,
                             dtype=real_same_precision_as(vec))
        vec_pad[0:len(vec)] = vec
        delta_f = 1.0 / (vec.delta_t * N)
        vectilde = FrequencySeries(zeros(n), delta_f=1.0,
                                   dtype=complex_same_precision_as(vec))
        fft(vec_pad, vectilde)

    else:
        # Previously this fell through and failed on an unbound name.
        raise TypeError("vec must be a TimeSeries or FrequencySeries, got %s"
                        % type(vec).__name__)

    return FrequencySeries(vectilde * DYN_RANGE_FAC, delta_f=delta_f,
                           dtype=complex64)

def get_waveform(approximant, phase_order, amplitude_order, spin_order,
                 template_params, start_frequency, sample_rate, length,
                 filter_rate):
    """Generate a waveform and return it as a padded FrequencySeries.

    Parameters
    ----------
    approximant : str
        Name of a time-domain or frequency-domain approximant.
    phase_order, amplitude_order, spin_order : int
        Orders forwarded to the waveform generator.
    template_params : object
        Row holding the source parameters; forwarded to the generator and to
        ``generate_detector_strain``.
    start_frequency : float
        Forwarded as ``f_lower``.
    sample_rate : float
        Rate at which a time-domain waveform is generated.
    length : int
        Number of time samples of the filter segment; sets the padded length
        and, for frequency-domain approximants, ``delta_f``.
    filter_rate : float
        Sample rate of the filter. Time-domain waveforms generated at a
        different ``sample_rate`` are resampled to it.

    Returns
    -------
    FrequencySeries
        Output of ``make_padded_frequency_series`` for the detector strain.

    Raises
    ------
    ValueError
        If ``approximant`` is neither a time-domain nor a frequency-domain
        approximant.
    """
    if approximant in td_approximants():
        hplus, hcross = get_td_waveform(template_params,
                                        approximant=approximant,
                                        phase_order=phase_order,
                                        spin_order=spin_order,
                                        delta_t=1./sample_rate,
                                        f_lower=start_frequency,
                                        amplitude_order=amplitude_order)
        hvec = generate_detector_strain(template_params, hplus, hcross)

        if filter_rate != sample_rate:
            delta_t = 1.0 / filter_rate
            hvec = resample_to_delta_t(hvec, delta_t)

    elif approximant in fd_approximants():
        # float() guards against integer division if an int rate is passed.
        delta_f = float(filter_rate) / length
        hp, hc = get_fd_waveform(template_params, approximant=approximant,
                                 phase_order=phase_order, delta_f=delta_f,
                                 spin_order=spin_order,
                                 f_lower=start_frequency,
                                 amplitude_order=amplitude_order)
        hvec = generate_detector_strain(template_params, hp, hc)

    else:
        # Previously hvec was left unbound and the return statement failed.
        raise ValueError("Unknown waveform approximant: %s" % approximant)

    # Use the length argument rather than the module-level filter_N, so the
    # function does not depend on a global defined later in the script.
    return make_padded_frequency_series(hvec, length)

# Names of all time-domain and frequency-domain approximants with duplicates
# removed; used below as the allowed ``choices`` of the --template-approximant
# and --signal-approximant options.
aprs = list(set(td_approximants() + fd_approximants()))

def outside_mchirp_window(template, signal, w):
    """Return True if the template chirp mass is more than ``w`` times the
    signal chirp mass above or below the signal chirp mass."""
    mchirp_tmplt = mass1_mass2_to_mchirp_eta(template.mass1,
                                             template.mass2)[0]
    mchirp_sig = mass1_mass2_to_mchirp_eta(signal.mass1, signal.mass2)[0]
    tolerance = w * mchirp_sig
    deviation = abs(mchirp_sig - mchirp_tmplt)
    return deviation > tolerance

def above_mchirp_window(template, signal, w):
    """Return True if the template chirp mass exceeds the signal chirp mass
    by more than ``w`` times the signal chirp mass."""
    mchirp_tmplt = mass1_mass2_to_mchirp_eta(template.mass1,
                                             template.mass2)[0]
    mchirp_sig = mass1_mass2_to_mchirp_eta(signal.mass1, signal.mass2)[0]
    tolerance = w * mchirp_sig
    excess = mchirp_tmplt - mchirp_sig
    return excess > tolerance

def below_mchirp_window(template, signal, w):
    """Return True if the template chirp mass falls short of the signal chirp
    mass by more than ``w`` times the signal chirp mass."""
    mchirp_tmplt = mass1_mass2_to_mchirp_eta(template.mass1,
                                             template.mass2)[0]
    mchirp_sig = mass1_mass2_to_mchirp_eta(signal.mass1, signal.mass2)[0]
    tolerance = w * mchirp_sig
    shortfall = mchirp_sig - mchirp_tmplt
    return shortfall > tolerance

# Command line definition. Options are registered in the order in which they
# should appear in --help.
parser = ArgumentParser()

# File output settings
parser.add_argument(
    "--match-file", dest="out_file",
    help="file to output match results",
    metavar="FILE")
parser.add_argument(
    "--verbose", action='store_true', default=False,
    help="Print verbose statements")

# Template settings
parser.add_argument(
    "--template-file", dest="bank_file",
    help="SimInspiral or SnglInspiral XML file containing the template parameters",
    metavar="FILE")
parser.add_argument(
    "--total-mass-divide", type=float,
    help="Total mass to switch from --template-approximant to --highmass-approximant.")
parser.add_argument(
    "--highmass-approximant",
    help="Waveform approximant for highmass templates.")
parser.add_argument(
    "--template-approximant",
    help="Waveform approximant for templates",
    choices=aprs)
parser.add_argument(
    "--template-phase-order",
    help="PN order to use for the template phase",
    default=-1, type=int)
parser.add_argument(
    "--template-amplitude-order",
    help="PN order to use for the template amplitude",
    default=-1, type=int)
parser.add_argument(
    "--template-spin-order",
    help="PN order to use for the template spin terms",
    default=-1, type=int)
parser.add_argument(
    "--template-start-frequency",
    help="Starting frequency for templates [Hz]",
    type=float)
parser.add_argument(
    "--template-sample-rate",
    help="Sample rate for templates [Hz]",
    type=float)

# Signal settings
parser.add_argument(
    "--signal-file", dest="sim_file",
    help="SimInspiral or SnglInspiral XML file containing the signal parameters",
    metavar="FILE")
parser.add_argument(
    "--signal-approximant",
    help="Waveform approximant for signals",
    choices=aprs)
parser.add_argument(
    "--signal-phase-order",
    help="PN order to use for the signal phase",
    default=-1, type=int)
parser.add_argument(
    "--signal-spin-order",
    help="PN order to use for the signal spin terms",
    default=-1, type=int)
parser.add_argument(
    "--signal-amplitude-order",
    help="PN order to use for the signal amplitude",
    default=-1, type=int)
parser.add_argument(
    "--signal-start-frequency",
    help="Starting frequency for signals [Hz]",
    type=float)
parser.add_argument(
    "--signal-sample-rate",
    help="Sample rate for signals [Hz]",
    type=float)
parser.add_argument(
    "--use-sky-location",
    help="Inject into a theoretical detector at the celestial North pole of a non-rotating Earth rather than overhead",
    action='store_true')

# Filtering settings
parser.add_argument(
    '--filter-low-frequency-cutoff', metavar='FREQ',
    help='low frequency cutoff of matched filter',
    type=float)
parser.add_argument(
    '--filter-high-freq-cutoff-isco-factor', metavar='FLOAT',
    help='High frequency cutoff of matched filter as a fraction of the injection\'s ISCO',
    type=float)
parser.add_argument(
    "--filter-sample-rate",
    help="Filter sample rate [Hz]",
    type=float)
parser.add_argument(
    "--filter-signal-length",
    help="Length of signal for filtering, shoud be longer than all waveforms and include some padding",
    type=int)

# PSD options
pycbc.psd.insert_psd_option_group(parser, output=False)

# Data reading (strain) options
pycbc.strain.insert_strain_option_group(parser)

# Hardware support: processing scheme and FFT backend options
pycbc.scheme.insert_processing_option_group(parser)
pycbc.fft.insert_fft_option_group(parser)

#Restricted maximization
parser.add_argument("--mchirp-window", type=str, metavar="FRACTION", help="Ignore templates whose chirp mass deviates from signal's one more than given fraction. Provide two comma separated numbers to have different bounds above and below the signal's, with below bound listed first.")
options = parser.parse_args()

# Interpret --mchirp-window. Three module-level results are consumed by the
# overlap loop further down:
#   mchirp_list          True when separate lower/upper bounds were given
#   mchirp_window_lower/ fractional bounds below/above the signal chirp mass
#   mchirp_window_upper  (only defined when mchirp_list is True)
#   mchirp_window_equal  symmetric fractional bound (mchirp_list is False)
try:
    if options.mchirp_window is None:
        # Option omitted: previously this crashed with a TypeError on the
        # "in" test. An infinite symmetric window excludes no template, so
        # the maximization is simply unrestricted.
        mchirp_list = False
        mchirp_window_equal = float('inf')
    elif ',' in options.mchirp_window:
        mchirp_list = True
        bounds = options.mchirp_window.split(",")
        mchirp_window_lower = float(bounds[0])
        mchirp_window_upper = float(bounds[1])
    else:
        mchirp_list = False
        mchirp_window_equal = float(options.mchirp_window)
except ValueError:
    # Report malformed numbers as a usage error instead of a traceback.
    parser.error("--mchirp-window must be a number or two comma separated "
                 "numbers, got '%s'" % options.mchirp_window)


# logging is used below but was never imported by this script.
import logging

pycbc.psd.verify_psd_options(options, parser)

# Validate cheap option combinations before any data is read. The test uses
# "is not None" so that it agrees with the check made when the template
# approximant is chosen in the overlap loop.
if options.total_mass_divide is not None and options.highmass_approximant is None:
    parser.error("You must provide a highmass-approximant if you want total-mass-divide.")

# If we are going to use h(t) to estimate a PSD we need h(t)
if options.psd_estimation:
    pycbc.strain.verify_strain_options(options, parser)
    logging.info("Obtaining h(t) for PSD generation")
    strain = pycbc.strain.from_cli(options, pycbc.DYN_RANGE_FAC)
else:
    strain = None

# The filter sample rate is the default for generating both templates and
# signals; the dedicated options override it when given.
template_sample_rate = options.filter_sample_rate
signal_sample_rate = options.filter_sample_rate

if options.template_sample_rate:
    template_sample_rate = options.template_sample_rate

if options.signal_sample_rate:
    signal_sample_rate = options.signal_sample_rate

# Processing scheme selected on the command line; entered further down.
ctx = pycbc.scheme.from_cli(options)

if options.verbose:
    print("Options read and verified, beginning banksim at %fs" % (elapsed_time()))

# Both input files may hold either a SnglInspiral or a SimInspiral table.
sngl_table_name = lsctables.SnglInspiralTable.tableName
sim_table_name = lsctables.SimInspiralTable.tableName

# Template bank: try the SnglInspiral table first, then SimInspiral.
indoc = ligolw_utils.load_filename(options.bank_file, False,
                                   contenthandler=mycontenthandler)
try:
    template_table = table.get_table(indoc, sngl_table_name)
except ValueError:
    template_table = table.get_table(indoc, sim_table_name)

# Simulation list: try the SimInspiral table first, then SnglInspiral.
indoc = ligolw_utils.load_filename(options.sim_file, False,
                                   contenthandler=mycontenthandler)
try:
    signal_table = table.get_table(indoc, sim_table_name)
except ValueError:
    signal_table = table.get_table(indoc, sngl_table_name)

if options.verbose:
    print("Bank and simulation files read at %fs" % (elapsed_time()))
    print("%s %s" % ("Number of Signal Waveforms: ", len(signal_table)))
    print("%s %s" % ("Number of Templates       : ", len(template_table)))
    print("Matches will be written to " + options.out_file)
    print("Recovered templates will be written to " + options.out_file + ".found")

# Sizes of the filter segment:
#   filter_N        number of time samples (signal length times sample rate)
#   filter_n        length of the matching one-sided frequency series
#   filter_delta_f  frequency resolution, the inverse of the signal length
filter_N = int(options.filter_signal_length * options.filter_sample_rate)
# Floor division keeps filter_n an integer whatever division semantics the
# interpreter uses; it is passed on as an array length.
filter_n = filter_N // 2 + 1
filter_delta_f = 1.0 / float(options.filter_signal_length)

if options.verbose:
    print("Reading and Interpolating PSD")
psd = pycbc.psd.from_cli(options, filter_n,
                         filter_delta_f, options.filter_low_frequency_cutoff,
                         strain=strain,
                         dyn_range_factor=pycbc.DYN_RANGE_FAC,
                         precision='single')

if options.verbose:
    print("PSD interpolated at %fs" % (elapsed_time()))
    print("Pregenerating Signals")
  
# Everything below runs inside the processing scheme built from the command
# line options (ctx). First every signal is generated once, then each template
# is matched against all of the stored signals.
with ctx: 
    # NOTE(review): presumably applies the FFT options given on the command
    # line -- confirm against pycbc.fft.
    pycbc.fft.from_cli(options)

    # One tuple per signal:
    # (stilde, s_norm, matches, signal_params, high_freq_cutoff)
    # where matches is a list that receives one entry per template.
    signals = []
    index = 0 
    for signal_params in signal_table:
        index += 1
        if options.verbose:
            update_progress(float(index)/len(signal_table))
        # Unless --use-sky-location is given, the sky angles stored with the
        # signal are overwritten with zero.
        if not options.use_sky_location:
            signal_params.latitude = 0.
            signal_params.longitude = 0.
        stilde = get_waveform(options.signal_approximant, 
                      options.signal_phase_order, 
                      options.signal_amplitude_order, 
                      options.signal_spin_order,
                      signal_params, 
                      options.signal_start_frequency, 
                      signal_sample_rate, 
                      filter_N, options.filter_sample_rate)
        # Optional upper cutoff: the given factor times f_SchwarzISCO of the
        # signal's total mass. None means no upper cutoff is passed on.
        if options.filter_high_freq_cutoff_isco_factor is not None:
            high_freq_cutoff = options.filter_high_freq_cutoff_isco_factor \
                    * f_SchwarzISCO(signal_params.mass1 + signal_params.mass2)
        else:
            high_freq_cutoff = None
        # The norm is computed before stilde is divided by the PSD.
        s_norm = sigmasq(stilde, psd=psd, 
                low_frequency_cutoff=options.filter_low_frequency_cutoff,
                high_frequency_cutoff=high_freq_cutoff)
        # match() below is called without a psd argument; presumably the PSD
        # weighting is carried by this one-time division -- TODO confirm.
        stilde /= psd
        signals.append((stilde, s_norm, [], signal_params,
                        high_freq_cutoff))

    if options.verbose:
      print
      print("Signals pregenerated at %fs" %(elapsed_time()))
      print("Calculating Overlaps")

    index = 0 
    # Calculate the overlaps
    for template_params in template_table:
        index += 1
        if options.verbose:
          update_progress(float(index)/len(template_table))
        # Reset per template: the template waveform is generated lazily, only
        # once some signal passes the chirp mass window checks below.
        h_norm = htilde = None
        for stilde, s_norm, matches, signal_params, signal_high_freq_cutoff in signals:
            # Every skipped pair still appends 0 so that the position in
            # matches stays equal to the template's position in the bank.
            # Check if we need to look at this
            # (no code visible in this script sets stilde to None)
            if stilde is None:
                matches.append(0)
                continue
            # Don't look if outside mchirp_window_equal range
            elif mchirp_list is False and outside_mchirp_window(
                    template_params, signal_params, mchirp_window_equal):
                matches.append(0)
                continue
            # Don't look if below mchirp_window_lower range
            elif mchirp_list is True and below_mchirp_window(
                    template_params, signal_params, mchirp_window_lower):
                matches.append(0)
                continue
            # Don't look if above mchirp_window_upper range
            elif mchirp_list is True and above_mchirp_window(
                    template_params, signal_params, mchirp_window_upper):
                matches.append(0)
                continue


            # Generate htilde if we haven't already done so
            if htilde is None:
                this_approximant = options.template_approximant
                # Templates at or above the total mass divide use the
                # high-mass approximant instead.
                if options.total_mass_divide is not None and (template_params.mass1+template_params.mass2) >= options.total_mass_divide:
                        this_approximant = options.highmass_approximant
                htilde = get_waveform(this_approximant,
                                      options.template_phase_order,
                                      options.template_amplitude_order,
                                      options.template_spin_order,
                                      template_params,
                                      options.template_start_frequency,
                                      template_sample_rate,
                                      filter_N, options.filter_sample_rate)
            
            # The template norm is recomputed for each signal because the
            # upper frequency cutoff is specific to the signal.
            h_norm = sigmasq(htilde, psd=psd,
                    low_frequency_cutoff=options.filter_low_frequency_cutoff,
                    high_frequency_cutoff=signal_high_freq_cutoff)

            # Only the first returned value (o) is stored; i is not used.
            o, i = match(htilde, stilde, v1_norm=h_norm, v2_norm=s_norm,
                    low_frequency_cutoff=options.filter_low_frequency_cutoff,
                    high_frequency_cutoff=signal_high_freq_cutoff)
            matches.append(o)

if options.verbose:
    print
    print("Overlaps finished at %fs" % (elapsed_time()))
    print("Determining maximum overlaps and outputting results")

# Find the maximum overlap in the bank and output to a file.
# One line per signal: best match, bank file, index of the best template,
# signal file, index of the signal and the signal norm.
with open(options.out_file, "w") as fout:
    for i, (stilde, s_norm, matches, sim_template, hfc) in enumerate(signals):
        best_match = max(matches)
        fields = (
            "%5.5f " % best_match,
            options.bank_file,
            str(matches.index(best_match)),
            options.sim_file,
            "%d" % i,
            "%5.5f\n" % s_norm,
        )
        match_str = " ".join(fields)
        fout.write(match_str)
# back to top