# Revision 29df1b564abf7a6ed5230a5e97a85da70c403c2a authored by Ian Harry on 15 June 2017, 14:00:24 UTC, committed by Duncan Brown on 15 June 2017, 14:00:24 UTC
# 1 parent 23fd945
# Raw File
# pycbc_single_template
#!/usr/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Calculate the SNR and CHISQ timeseries for either a chosen template, or 
a specific Nth loudest coincident event.
"""
import sys, logging, argparse, numpy, pycbc, h5py
from pycbc import vetoes, psd, waveform, strain, scheme, fft, filter
from pycbc.io import WaveformArray
from pycbc import events
from pycbc.filter import resample_to_delta_t
from pycbc.types import zeros, complex64
from pycbc.types import complex_same_precision_as
import pycbc.waveform.utils
import pycbc.version

def subtract_template(stilde, template, snr, trigger_time, flow):
    """Subtract the template, scaled by the SNR at the trigger, from stilde.

    The SNR sample index is found by truncating
    ``(trigger_time - snr.start_time) / snr.delta_t`` to an integer. If that
    index falls outside ``snr`` nothing is subtracted.

    Parameters
    ----------
    stilde :
        Frequency-domain data; must provide a ``psd`` attribute. It is
        modified in place (``-=``) when a subtraction happens.
    template :
        Frequency-domain template to remove.
    snr :
        SNR series providing ``start_time``, ``delta_t``, ``len()`` and
        integer indexing.
    trigger_time :
        Time at which to read the SNR and to which the template is shifted.
    flow :
        Low frequency cutoff passed on to ``filter.sigma``.

    Returns
    -------
    stilde, after any subtraction.
    """
    offset = trigger_time - snr.start_time
    sample = int(offset / snr.delta_t)
    if not (0 <= sample < len(snr)):
        # The trigger is not covered by this SNR series: leave data untouched
        return stilde

    # Weight the template by the SNR sample over the template's sigma, then
    # shift it by the trigger's offset from the start of the SNR series
    norm = filter.sigma(template, psd=stilde.psd, low_frequency_cutoff=flow)
    scaled = template * snr[sample] / norm
    stilde -= pycbc.waveform.utils.apply_fseries_time_shift(scaled, offset)
    return stilde

def select_segments(fname, anal_name, data_name, ifo, time, pad_data):
    """Pick the analysis segment and the data segment that go with ``time``.

    Parameters
    ----------
    fname :
        Segment file handed to ``events.select_segments_by_definer``.
    anal_name :
        Name of the segment list holding the analysed segments.
    data_name :
        Name of the segment list holding the data-read segments.
    ifo :
        Detector whose segments are selected.
    time :
        The time of interest.
    pad_data :
        Padding applied at both ends of a data segment when checking that it
        encloses the analysis segment.

    Returns
    -------
    (anal_time, data_time) : two ``(start, end)`` tuples.

    Raises
    ------
    ValueError
        If no analysis segment starts at or before ``time``, if no data
        segment strictly contains ``time``, or if several data segments
        contain ``time`` but none of them encloses the analysis segment.
    """
    anal_segs = events.select_segments_by_definer(fname, anal_name, ifo)
    data_segs = events.select_segments_by_definer(fname, data_name, ifo)

    # Anal segs should be disjoint, so first find the seg containing time
    s = numpy.array([t[0] for t in anal_segs])
    e = numpy.array([t[1] for t in anal_segs])
    # ensure sorted
    sorting = s.argsort()
    s = s[sorting]
    e = e[sorting]
    # side='right' so a time lying exactly on a segment start is assigned to
    # that segment rather than to the one before it
    idx = numpy.searchsorted(s, time, side='right') - 1
    if idx < 0:
        # Without this check an index of -1 would silently wrap around and
        # select the *last* analysis segment
        err_msg = "Cannot find an analysis segment containing %s" \
                  % (str(time))
        raise ValueError(err_msg)
    anal_time = (s[idx], e[idx])

    # Now need to find the corresponding data_seg. This could be complicated
    # as in edge cases the anal_time tuple could be completely contained within
    # *two* data blocks (think analysis chunk slightly longer than minimum).
    # We need to choose the *right* block to reproduce what the search did.

    s2 = numpy.array([t[0] for t in data_segs])
    e2 = numpy.array([t[1] for t in data_segs])
    lgc = (s2 < time) & (e2 > time)
    s2 = s2[lgc]
    e2 = e2[lgc]
    if len(s2) == 0:
        err_msg = "Cannot find a data segment within %s" %(str(time))
        raise ValueError(err_msg)

    if len(s2) == 1:
        return anal_time, (s2[0], e2[0])

    # The tricksy case. Among the data segments that enclose anal_time (after
    # padding) choose the one with the least data outside anal_time, i.e. the
    # largest fractional overlap with it.
    data_time = None
    overlap = None
    for start, end in zip(s2, e2):
        if start + pad_data > anal_time[0] or end - pad_data < anal_time[1]:
            # The analysis time must lie within the data time, otherwise
            # this clearly is not the right data segment!
            continue

        curr_nonoverlap = anal_time[0] - start
        curr_nonoverlap += end - anal_time[1]
        if (overlap is None) or (curr_nonoverlap < overlap):
            overlap = curr_nonoverlap
            data_time = (start, end)

    if data_time is None:
        # Every candidate was rejected above; fail with a clear message
        # instead of an unbound-variable error at the return statement
        err_msg = "Cannot find a data segment containing the analysis "
        err_msg += "segment %s for time %s" % (str(anal_time), str(time))
        raise ValueError(err_msg)

    return anal_time, data_time

# Command line interface ######################################################
# Options defined here describe the template and the output; the PSD, strain,
# segment, processing-scheme and FFT options are added by the pycbc option
# group helpers further down.
parser = argparse.ArgumentParser(usage='',
    description="Single template gravitational-wave followup")
parser.add_argument('--version', action=pycbc.version.Version)
parser.add_argument('--output-file', required=True)
parser.add_argument('--subtract-template', action='store_true')
parser.add_argument("-V", "--verbose", action="store_true",
                  help="print extra debugging information", default=False )
parser.add_argument("--low-frequency-cutoff", type=float,
                  help="The low frequency cutoff to use for filtering (Hz)")
parser.add_argument("--chisq-bins", default="0", type=str, help=
                    "Number of frequency bins to use for power chisq.")
parser.add_argument("--minimum-chisq-bins", default=0, type=int, help=
                    "If the chisq bin formula fails, default to this number.")
parser.add_argument("--trigger-time", type=float, default=0,
                  help="The central time of the trigger to use.")
parser.add_argument("--use-params-of-closest-injection", action="store_true",
                  default=False,
                  help="If given, use the injection with end_time closest to "
                       "--trigger-time as the waveform for filtering. If "
                       "using this do not supply mass and spin options. "
                       "Using this requires trigger-time and injection-file "
                       "to be given.")
# add approximant arg
waveform.bank.add_approximant_arg(parser,
                  help="The name of the approximant to use for filtering. "
                      "Do not use if using --use-params-of-closest-injection.")
parser.add_argument("--mass1", type=float,
                  help="The mass of the first component object. "
                      "Do not use if using --use-params-of-closest-injection.")
parser.add_argument("--mass2", type=float,
                  help="The mass of the second component object. "
                      "Do not use if using --use-params-of-closest-injection.")
parser.add_argument("--spin1z", type=float, default=0,
                  help="The aligned spin of the first component object. "
                      "Do not use if using --use-params-of-closest-injection.")
parser.add_argument("--spin2z", type=float, default=0,
                  help="The aligned spin of the second component object. "
                      "Do not use if using --use-params-of-closest-injection.")
parser.add_argument("--template-start-frequency", type=float, default=None,
                  help="If given, use this as a start frequency for "
                       "generating the template. If not given the "
                       "--low-frequency-cutoff is used.")
# Optional arguments for precessing templates
parser.add_argument("--spin1x", type=float, default=0,
                  help="The non-aligned spin of the first component object. "
                    "Default = 0.")
parser.add_argument("--spin2x", type=float, default=0,
                  help="The non-aligned spin of the second component object. "
                    "Default = 0.")
parser.add_argument("--spin1y", type=float, default=0,
                  help="The non-aligned spin of the first component object. "
                    "Default = 0.")
parser.add_argument("--spin2y", type=float, default=0,
                  help="The non-aligned spin of the second component object. "
                    "Default = 0.")
parser.add_argument("--inclination", type=float, default=0,
                  help="The inclination of the source w.r.t the observer. "
                    "Default = 0.")
parser.add_argument("--u-val", type=float, default=None,
                  help="The ratio between hplus and hcross to use in the "
                    "template, according to h(t) = hplus * u_val + hcross. "
                    "If not given only hplus is used.")
parser.add_argument("--window", type=float,
                  help="Time to save around the trigger time (if given)")
# choices is a plain list of Python ints (same values as before, -1 .. 8) so
# that argparse reports invalid choices without NumPy scalar reprs
parser.add_argument("--order", type=int,
                  help="The integer half-PN order at which to generate"
                       " the approximant. Default is -1 which indicates to use"
                       " approximant defined default.", default=-1,
                       choices=list(range(-1, 9)))
parser.add_argument("--taper-template", choices=["start","end","startend"],
                    help="For time-domain approximants, taper the start and/or"
                    " end of the waveform before FFTing.")

# These options can be used to identify start/end times
parser.add_argument("--inspiral-segments",
        help="XML file containing the inspiral analysis segments. "
             "Only used with the --statmap-file option")
parser.add_argument("--data-read-name",
        help="name of the segmentlist containing the data read in by each job "
             "from the inspiral segment file")
parser.add_argument("--data-analyzed-name",
        help="name of the segmentlist containing the data analysed by each job "
             "from the inspiral segment file")

# Add options groups
psd.insert_psd_option_group(parser)
strain.insert_strain_option_group(parser)
strain.StrainSegments.insert_segment_option_group(parser)
scheme.insert_processing_option_group(parser)
fft.insert_fft_option_group(parser)
opt = parser.parse_args()

# The output file is created straight away; results are added to it later on
f = h5py.File(opt.output_file, 'w')
# The detector name is taken to be the first two characters of the channel
ifo = opt.channel_name[:2]

# Optionally take the start/end times from the inspiral segment file ##########
if opt.inspiral_segments:
    # The trig-start/end and data-start/end values have to match those of the
    # inspiral jobs. Zero-padding is also only applied when the trig start/end
    # times explicitly ask for it.
    (trig_start, trig_end), (data_start, data_end) = select_segments(
        opt.inspiral_segments, opt.data_analyzed_name, opt.data_read_name,
        ifo, opt.trigger_time, opt.pad_data)
    opt.trig_start_time = trig_start
    opt.trig_end_time = trig_end
    # The GPS start/end options are the data segment with the padding removed
    opt.gps_start_time = data_start + opt.pad_data
    opt.gps_end_time = data_end - opt.pad_data
    f.attrs['event_time'] = opt.trigger_time

###############################################################################

# Sanity-check the option values, including any that were overridden above
psd.verify_psd_options(opt, parser)
strain.verify_strain_options(opt, parser)
strain.StrainSegments.verify_segment_options(opt, parser)
scheme.verify_processing_options(opt, parser)
fft.verify_fft_options(opt, parser)
pycbc.init_logging(opt.verbose)

# Build the processing context, load the strain and set up the segmentation
ctx = scheme.from_cli(opt)
gwstrain = strain.from_cli(opt, pycbc.DYN_RANGE_FAC)
strain_segments = strain.StrainSegments.from_cli(opt, gwstrain)

if not opt.use_params_of_closest_injection:
    # Template parameters are given directly on the command line
    param_names = ('mass1', 'mass2',
                   'spin1x', 'spin1y', 'spin1z',
                   'spin2x', 'spin2y', 'spin2z')
    cli_params = dict((name, getattr(opt, name)) for name in param_names)
    row = WaveformArray.from_kwargs(**cli_params)

# Main body: build the template, matched-filter every data segment that
# overlaps the requested time range, compute the power chisq, and write the
# resulting time series to the output file.
with ctx:
    fft.from_cli(opt)
    flow = opt.low_frequency_cutoff
    # don't want a template f_lower that is None or 0.0
    template_flow = opt.template_start_frequency or flow
    flen = strain_segments.freq_len
    delta_f = strain_segments.delta_f

    logging.info("Making frequency-domain data segments")
    segments = strain_segments.fourier_segments()

    logging.info("Calculating the PSDs")
    psd.associate_psds_to_segments(opt, segments, gwstrain, flen, delta_f,
                             flow, dyn_range_factor=pycbc.DYN_RANGE_FAC,
                             precision='single')

    logging.info("Making template: %s" % opt.approximant)
    if opt.use_params_of_closest_injection:
        if not hasattr(gwstrain, 'injections') or not opt.trigger_time:
            err_msg = "To use --use-params-of-closest-injection you must "
            err_msg += "be making injections with an injection file and using "
            err_msg += "the --trigger-time option."
            raise ValueError(err_msg)
        logging.info("Making template directly from injection.")
        inj_set = gwstrain.injections
        # Pick the injection whose end time is nearest to the trigger time
        end_times = numpy.array(inj_set.end_times())
        inj_idx = abs(end_times - opt.trigger_time).argmin()
        inj = inj_set.table[inj_idx]
        # Use injection values for things like choosing number of chisq bins
        row = WaveformArray.from_kwargs(
            mass1=inj.mass1,
            mass2=inj.mass2,
            spin1x=inj.spin1x,
            spin1y=inj.spin1y,
            spin1z=inj.spin1z,
            spin2x=inj.spin2x,
            spin2y=inj.spin2y,
            spin2z=inj.spin2z)
        # FIXME: Don't like hardcoded 16384 here
        # NOTE: f_lower is set in strain.py as inj.f_lower, but this is a
        #       little unclear to see and caused me some problems! Stating
        #       this clearly here so no-one makes the same mistake as me.
        td_template = inj_set.make_strain_from_inj_object(inj, 1./16384.,
                                    opt.channel_name[0:2], f_lower=inj.f_lower,
                                    distance_scale=opt.injection_scale_factor)
        td_template = resample_to_delta_t(td_template, gwstrain.delta_t,
                                          method='ldas')
        td_template = td_template * pycbc.DYN_RANGE_FAC
        # Shift the epoch so times are measured relative to the injection end
        td_template._epoch -= inj.get_end(site=opt.channel_name[0])
        approximant = inj.waveform
        template = waveform.td_waveform_to_fd_waveform(td_template, length=flen)
        template = template.astype(complex_same_precision_as(segments[0]))
    else:
        approximant = waveform.bank.parse_approximant_arg(opt.approximant, row)[0]
        if opt.u_val is None:
            template = waveform.get_waveform_filter(
                                    zeros(flen, dtype=complex64),
                                    approximant=approximant,
                                    template=row[0],
                                    inclination=opt.inclination,
                                    taper=opt.taper_template,
                                    f_lower=template_flow, delta_f=delta_f,
                                    delta_t=gwstrain.delta_t,
                                    distance = 1.0/pycbc.DYN_RANGE_FAC)
        else:
            tp, tc = waveform.get_two_pol_waveform_filter(
                                    zeros(flen, dtype=complex64),
                                    zeros(flen, dtype=complex64), row[0],
                                    approximant=approximant,
                                    inclination=opt.inclination,
                                    taper=opt.taper_template,
                                    f_lower=template_flow, delta_f=delta_f,
                                    delta_t=gwstrain.delta_t)
            template = tc.multiply_and_add(tp, opt.u_val)

    # FIXME: should probably switch to something like what is done for parsing
    # the approximant for chisq bins at some point
    # Build a throwaway object exposing the template parameters as
    # ``parse_row.params.<name>`` for the chisq bin option parser
    class t(object):
        pass
    parse_row = t()
    parse_row.params = t()
    for param in row.fieldnames:
        setattr(parse_row.params, param, row[param][0])

    chisq_bins_float = vetoes.SingleDetPowerChisq.parse_option(parse_row,
        opt.chisq_bins)
    if numpy.isnan(chisq_bins_float) or (int(chisq_bins_float) <
                                    opt.minimum_chisq_bins):
        if opt.minimum_chisq_bins:
            chisq_bins = opt.minimum_chisq_bins
            logging.warning( "Number of chisq bins is less than minimum or is NaN."
                            + (" Using %d bins." % opt.minimum_chisq_bins) )
        else:
            raise ValueError(
                "Chisq bins is NaN or negative and no minimum is set.")
    else:
        chisq_bins = int(chisq_bins_float)

    f['template'] = template.numpy()
    snrs, chisqs = [], []
    raw_bins = [[] for b in range(chisq_bins)]

    # Time range to store: the trigger start/end times when given, otherwise
    # the GPS start/end times with the segment padding removed
    if opt.trig_start_time:
        start_time = opt.trig_start_time
        end_time = opt.trig_end_time
    else:
        start_time = opt.gps_start_time + opt.segment_start_pad
        end_time = opt.gps_end_time - opt.segment_end_pad

    # Optionally narrow the range to +/- window around the trigger time
    if opt.window:
        start_time_wind = opt.trigger_time - opt.window
        if start_time_wind > start_time:
            start_time = start_time_wind
        end_time_wind = opt.trigger_time + opt.window
        if end_time_wind < end_time:
            end_time = end_time_wind

    for s_num, stilde in enumerate(segments):
        start = stilde.epoch + stilde.analyze.start / float(opt.sample_rate)
        end = stilde.epoch + stilde.analyze.stop / float(opt.sample_rate)

        # Skip segments that end before the range; stop once past its end
        if end < start_time:
            continue

        if start > end_time:
            break

        logging.info("Filtering segment %s" % s_num)
        snr, corr, norm = filter.matched_filter_core(template, stilde,
                                    psd=stilde.psd,
                                    low_frequency_cutoff=flow)
        snr *= norm

        if opt.subtract_template:
            stilde = subtract_template(stilde, template,
                                       snr, opt.trigger_time, flow)
            snr, corr, norm = filter.matched_filter_core(template, stilde,
                                psd=stilde.psd,
                                low_frequency_cutoff=flow)
            # The re-filtered SNR needs the same normalisation as above,
            # otherwise the stored SNR is unnormalised in this mode
            snr *= norm

        logging.info("calculating chisq")
        chisq, raw_bin = vetoes.power_chisq(template, stilde, chisq_bins, stilde.psd,
                                    low_frequency_cutoff=flow,
                                    return_bins=True)
        chisq /= chisq_bins * 2 - 2

        # Keep only the analysed part of each segment
        snrs.append(snr[stilde.analyze])
        chisqs.append(chisq[stilde.analyze])

        for i in range(chisq_bins):
            raw_bins[i].append(raw_bin[i][stilde.analyze].numpy())

    if not snrs:
        # No segment was filtered, so there is nothing to index or write
        raise ValueError("No data segments overlap the requested time range "
                         "%s - %s" % (start_time, end_time))

    # Sample spacing shared by every stored series
    out_delta_t = snrs[0].delta_t
    # Indices of the requested range within the concatenated series
    sidx = int((start_time - snrs[0].start_time) / out_delta_t)
    eidx = int((end_time - start_time) / out_delta_t) + sidx

    if sidx < 0:
        err_msg = "Ian has probably broken single_template again. Please email "
        err_msg += "with the command line that is raising this error and "
        err_msg += "shout at him. Beer may speed up the fixing process."
        raise ValueError(err_msg)

    for i in range(chisq_bins):
        key = 'chisq_bins/%s' % i
        f[key] = numpy.concatenate(raw_bins[i])[sidx:eidx]
        f[key].attrs['start_time'] = start_time
        f[key].attrs['delta_t'] = out_delta_t

    # Bin edges are returned as indices; multiply by delta_f to store them
    # as frequencies
    f['chisq_boundaries'] = numpy.array(vetoes.power_chisq_bins(template, chisq_bins, stilde.psd,
                                    low_frequency_cutoff=flow)) * template.delta_f

    f['snr'] = numpy.concatenate([series.numpy() for series in snrs])[sidx:eidx]
    f['snr'].attrs['start_time'] = start_time
    f['snr'].attrs['delta_t'] = out_delta_t

    f['chisq'] = numpy.concatenate([series.numpy() for series in chisqs])[sidx:eidx]
    f['chisq'].attrs['start_time'] = start_time
    f['chisq'].attrs['delta_t'] = out_delta_t
    f.attrs['approximant'] = approximant
    f.attrs['ifo'] = ifo
    f.attrs['command_line'] = ' '.join(sys.argv)

# Close the output file explicitly now that everything has been written
f.close()

logging.info("Finished")

# back to top