https://github.com/gwastro/pycbc
Raw File
Tip revision: 399438de31cd182a37994c9949721cb9c42ae59e authored by Duncan Brown on 18 September 2015, 19:06:34 UTC
Merge pull request #296 from lppekows/general_urls
Tip revision: 399438d
pycbc_faithsim
#! /usr/bin/env python
## Copyright (C) 2012  Alex Nitz
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
import sys
from numpy import loadtxt,complex64,float32
from optparse import OptionParser
from glue.ligolw import utils as ligolw_utils
from glue.ligolw import table, lsctables, ligolw
from math import pow

from scipy.interpolate import interp1d

from pycbc.pnutils import mass1_mass2_to_mchirp_eta
from pycbc.waveform import get_td_waveform, get_fd_waveform, td_approximants, fd_approximants
from pycbc import DYN_RANGE_FAC
from pycbc.types import FrequencySeries, TimeSeries, zeros, real_same_precision_as, complex_same_precision_as
from pycbc.filter import match, overlap, sigma
from pycbc.scheme import CPUScheme, CUDAScheme
from pycbc.fft import fft
import pycbc.psd 

class ContentHandler(ligolw.LIGOLWContentHandler):
    """Content handler handed to the LIGO_LW loader for the parameter file."""


# Run the handler through lsctables.use_in before it is used for loading
# (per the glue documentation this makes the parser build lsctables tables).
lsctables.use_in(ContentHandler)

def update_progress(progress):
    """Redraw a 50-character text progress bar on standard output.

    Parameters
    ----------
    progress : int
        Percentage completed. One '#' is drawn per two percent. Values
        outside 0-100 are clamped when drawing the bar but are printed
        unchanged in the numeric label.

    Notes
    -----
    The bar is emitted with ``sys.stdout.write`` instead of the ``print``
    statement, so the function behaves the same under Python 2 and Python 3
    and does not depend on the soft space that ``print x,`` inserts between
    successive calls. When ``progress`` equals 100 the line is finished with
    " Done" and a newline.
    """
    bar_width = 50
    # Floor division keeps the cell count integral even when true division
    # is in effect or a float percentage is passed in.
    filled = max(0, min(bar_width, int(progress // 2)))
    bar = '#' * filled + ' ' * (bar_width - filled)
    # The leading carriage returns move the cursor back to the start of the
    # line so that each call overwrites the previously drawn bar.
    sys.stdout.write('\r\r[{0}] {1}%'.format(bar, progress))
    if progress == 100:
        sys.stdout.write(' Done\n')
    sys.stdout.flush()

## Remove the need for these functions ########################################
def make_padded_frequency_series(vec,filter_N=None):
    """Return a zero-padded, rescaled frequency-domain version of ``vec``.

    A TimeSeries is zero-padded to ``N`` samples and Fourier transformed.
    A FrequencySeries is copied (truncated or zero-extended) into
    ``N // 2 + 1`` bins. In both cases the result is multiplied by
    DYN_RANGE_FAC and returned as a complex64 FrequencySeries.

    Parameters
    ----------
    vec : TimeSeries or FrequencySeries
        The waveform to pad.
    filter_N : number, optional
        Time-domain length ``N`` to pad to; it is truncated to an integer.
        When omitted, ``N = 2 ** (ceil(log2(len(vec))) + 1)``, i.e. a power
        of two larger than the input, which prevents wraparound effects.

    Returns
    -------
    FrequencySeries
        ``N // 2 + 1`` complex64 samples. ``delta_f`` is taken from ``vec``
        for frequency-domain input and is ``1 / (vec.delta_t * N)`` for
        time-domain input.

    Raises
    ------
    TypeError
        If ``vec`` is neither a TimeSeries nor a FrequencySeries.
    """
    # ceil and log are not imported at module level, so import them here.
    from math import ceil, log

    # Lengths must be integers: ceil() returns a float on Python 2 and the
    # caller may pass a float sample count.
    if filter_N is None:
        N = 2 ** (int(ceil(log(len(vec), 2))) + 1)
    else:
        N = int(filter_N)
    n = N // 2 + 1

    if isinstance(vec, FrequencySeries):
        delta_f = vec.delta_f
        # delta_f=1.0 is only a placeholder; the real spacing is applied
        # when the result is re-wrapped below.
        vectilde = FrequencySeries(zeros(n, dtype=complex_same_precision_as(vec)),
                                   delta_f=1.0, copy=False)
        # Copy only the bins that exist in both series.
        cplen = min(len(vectilde), len(vec))
        vectilde[0:cplen] = vec[0:cplen]
    elif isinstance(vec, TimeSeries):
        delta_f = 1.0 / (vec.delta_t * N)
        vec_pad = TimeSeries(zeros(N), delta_t=vec.delta_t,
                             dtype=real_same_precision_as(vec))
        vec_pad[0:len(vec)] = vec
        # Placeholder delta_f, as above.
        vectilde = FrequencySeries(zeros(n), delta_f=1.0,
                                   dtype=complex_same_precision_as(vec))
        fft(vec_pad, vectilde)
    else:
        raise TypeError("vec must be a TimeSeries or FrequencySeries, got %s"
                        % type(vec).__name__)

    return FrequencySeries(vectilde * DYN_RANGE_FAC, delta_f=delta_f,
                           dtype=complex64)

def get_waveform(approximant, phase_order, amplitude_order, spin_order, template_params, start_frequency, sample_rate, length):
    """Generate a waveform and return it as a padded frequency series.

    Parameters
    ----------
    approximant : str
        Name of the approximant; must appear in ``td_approximants()`` or
        ``fd_approximants()``. Time-domain generation is preferred when the
        name appears in both lists.
    phase_order, amplitude_order, spin_order : int
        Forwarded unchanged to the waveform generator.
    template_params : object
        Source parameters (a row of the parameter table in this script),
        forwarded as the first argument of the waveform generator.
    start_frequency : float
        Forwarded as ``f_lower``.
    sample_rate : float
        Sample rate; time-domain waveforms use ``delta_t = 1 / sample_rate``.
    length : number
        Time-domain length in samples that the result is padded to.
        Frequency-domain waveforms use ``delta_f = sample_rate / length``.

    Returns
    -------
    FrequencySeries
        The plus polarization, processed by
        ``make_padded_frequency_series``. The cross polarization is discarded.

    Raises
    ------
    ValueError
        If ``approximant`` is not a known time- or frequency-domain
        approximant.
    """
    if approximant in td_approximants():
        hplus, _hcross = get_td_waveform(template_params, approximant=approximant,
                                         spin_order=spin_order,
                                         phase_order=phase_order,
                                         delta_t=1.0 / sample_rate,
                                         f_lower=start_frequency,
                                         amplitude_order=amplitude_order)
    elif approximant in fd_approximants():
        # float() guards against integer division truncating the spacing.
        delta_f = float(sample_rate) / length
        hplus, _hcross = get_fd_waveform(template_params, approximant=approximant,
                                         spin_order=spin_order,
                                         phase_order=phase_order,
                                         delta_f=delta_f,
                                         f_lower=start_frequency,
                                         amplitude_order=amplitude_order)
    else:
        raise ValueError("Unknown approximant: %s" % (approximant,))

    # Pad to the length given by the caller rather than relying on a
    # module-level global of the same value.
    return make_padded_frequency_series(hplus, length)

###############################################################################

# Every approximant name known to either waveform generator, and the names
# of the analytic PSD models; both are used as option choices below.
aprs = list(set(td_approximants() + fd_approximants()))
psd_names = pycbc.psd.get_lalsim_psd_list()

print("STARTING")

# File I/O settings
parser = OptionParser()
parser.add_option("--param-file", dest="bank_file", help="Sngl or Sim Inspiral Table containing waveform parameters", metavar="FILE")
parser.add_option("--match-file", dest="out_file", help="file to output match results", metavar="FILE")

parser.add_option("--asd-file", dest="asd_file", help="ASD data", metavar="FILE")
parser.add_option("--psd", dest="psd", help="Analytic PSD model from LALSimulation", choices=psd_names)

# Waveform generation settings: both waveforms take the same set of options,
# so they are declared from a single template to keep them consistent.
for label in ("waveform1", "waveform2"):
    flag = "--" + label
    parser.add_option(flag + "-approximant", help=label + " Approximant Name: " + str(aprs), choices=aprs)
    parser.add_option(flag + "-phase-order", help="PN order to use for the phase", default=-1, type=int)
    parser.add_option(flag + "-amplitude-order", help="PN order to use for the amplitude", default=-1, type=int)
    parser.add_option(flag + "-spin-order", help="Spin terms up to the given pN order are included", default=-1, type=int)
    parser.add_option(flag + "-start-frequency", help="Starting frequency for " + label, type=float)

# Filter settings
parser.add_option('--filter-low-frequency-cutoff', metavar='FREQ', help='low frequency cutoff of matched filter', type=float)
parser.add_option("--filter-sample-rate", help="Filter Sample Rate [Hz]", type=float)
parser.add_option("--filter-waveform-length", help="Length of waveform for filtering, should be longer than all waveforms and include some padding", type=int)

parser.add_option("--cuda", action="store_true")
(options, args) = parser.parse_args()

if options.psd and options.asd_file:
    parser.error("PSD and asd-file options are mutually exclusive")

# Fail early with a usage message instead of a traceback (or a run that can
# only produce failure markers) when an essential option is missing.
for dest, flag in (("out_file", "--match-file"),
                   ("waveform1_approximant", "--waveform1-approximant"),
                   ("waveform2_approximant", "--waveform2-approximant"),
                   ("filter_sample_rate", "--filter-sample-rate"),
                   ("filter_waveform_length", "--filter-waveform-length")):
    if getattr(options, dest) is None:
        parser.error(flag + " is required")


# Choose the processing scheme: CUDA when --cuda was given, CPU otherwise.
ctx = CUDAScheme() if options.cuda else CPUScheme()

# Read the parameter file. It may carry either a sngl_inspiral or a
# sim_inspiral table, so fall back to the latter when the former lookup
# raises ValueError.
indoc = ligolw_utils.load_filename(options.bank_file, False,
                                   contenthandler=ContentHandler)
try:
    waveform_table = table.get_table(indoc,
                                     lsctables.SnglInspiralTable.tableName)
except ValueError:
    waveform_table = table.get_table(indoc,
                                     lsctables.SimInspiralTable.tableName)

# Open the output file that receives one line of results per waveform.
fout = open(options.out_file, "w")
print("Writing matches to " + options.out_file)

# Time-domain filter length in samples (waveform length times sample rate;
# the length option is presumably in seconds -- confirm). It is forced to an
# integer because it is later used as an array length.
filter_N = int(options.filter_waveform_length * options.filter_sample_rate)
# Number of frequency bins of the corresponding one-sided spectrum.
filter_n = filter_N // 2 + 1
delta_f = float(options.filter_sample_rate) / filter_N

# Format the count into the string: passing two arguments to print() would
# output a tuple under Python 2.
print("Number of Waveforms      : %d" % len(waveform_table))

print("Reading and Interpolating PSD")
# Build the PSD from an ASD text file or an analytic model; with neither
# option the filters run without a PSD.
if options.asd_file:
    psd = pycbc.psd.from_txt(options.asd_file, filter_n,
                             delta_f, options.filter_low_frequency_cutoff)
elif options.psd:
    psd = pycbc.psd.from_string(options.psd, filter_n, delta_f,
                                options.filter_low_frequency_cutoff)
else:
    psd = None

if psd is not None:
    # The waveforms are scaled by DYN_RANGE_FAC, so the PSD is scaled by its
    # square, then stored in single precision.
    psd *= DYN_RANGE_FAC ** 2
    psd = FrequencySeries(psd, delta_f=psd.delta_f, dtype=float32)

# Per-waveform results; an entry of -1 marks a waveform pair that failed.
matches = []
overlaps = []
s1s = []
s2s = []
print("Calculating Overlaps")
num_waveforms = len(waveform_table)
flow = options.filter_low_frequency_cutoff
with ctx:
    # Compare the two waveform models for every row of the parameter table.
    for index, waveform_params in enumerate(waveform_table, 1):
        # Floor division keeps the percentage an integer on any Python.
        update_progress(index * 100 // num_waveforms)

        try:
            htilde1 = get_waveform(options.waveform1_approximant,
                                   options.waveform1_phase_order,
                                   options.waveform1_amplitude_order,
                                   options.waveform1_spin_order,
                                   waveform_params,
                                   options.waveform1_start_frequency,
                                   options.filter_sample_rate,
                                   filter_N)

            htilde2 = get_waveform(options.waveform2_approximant,
                                   options.waveform2_phase_order,
                                   options.waveform2_amplitude_order,
                                   options.waveform2_spin_order,
                                   waveform_params,
                                   options.waveform2_start_frequency,
                                   options.filter_sample_rate,
                                   filter_N)

            # match() also returns an index, which is not needed here.
            m, _match_index = match(htilde1, htilde2, psd=psd,
                                    low_frequency_cutoff=flow)
            o = overlap(htilde1, htilde2, psd=psd,
                        low_frequency_cutoff=flow)
            s1 = sigma(htilde1, psd=psd, low_frequency_cutoff=flow)
            s2 = sigma(htilde2, psd=psd, low_frequency_cutoff=flow)
        except Exception as err:
            # Best effort: record the failure and carry on with the next
            # row. Catching Exception (not a bare except) lets Ctrl-C and
            # SystemExit still stop the run.
            print("Unable to generate waveforms: %s" % (err,))
            m = o = s1 = s2 = -1

        matches.append(m)
        overlaps.append(o)
        s1s.append(s1)
        s2s.append(s2)

# Write one "match overlap sigma1 sigma2" line per waveform and always
# close the output file afterwards.
try:
    for m, o, s1, s2 in zip(matches, overlaps, s1s, s2s):
        fout.write("%5.5f %5.5f %5.5f %5.5f\n" % (m, o, s1, s2))
finally:
    fout.close()


back to top