Revision 0d1cfeba950fb9cba8e8bd9e757f3cf8f7b501ff authored by Duncan Brown on 25 September 2015, 21:21:32 UTC, committed by Duncan Brown on 25 September 2015, 21:21:32 UTC
Add gating files at the workflow level, allow them to be URLs
2 parents d422567 + 8ec69c4
Raw File
pycbc_multi_inspiral
#!/usr/bin/env /usr/bin/python

# Copyright (C) 2014 Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import logging
from collections import defaultdict
import argparse
from pycbc import vetoes, psd, waveform, events, strain, scheme, fft, DYN_RANGE_FAC
from pycbc.filter import MatchedFilterControl
from pycbc.types import TimeSeries, zeros, float32, complex64
from pycbc.types import MultiDetOptionAction

# Command-line options owned by this executable. Further option groups
# (strain, segments, PSD, processing scheme, FFT) are attached to the same
# parser by library helpers later in the script.
parser = argparse.ArgumentParser(
    description="Find multiple detector gravitational-wave triggers.",
    usage='')

# General, input/output and threshold options.
parser.add_argument("-V", "--verbose", default=False, action="store_true",
                    help="print extra debugging information")
parser.add_argument("--output", type=str)
parser.add_argument("--instruments", type=str, required=True, nargs="+",
                    help="List of instruments to analyze.")
parser.add_argument("--bank-file", type=str)
parser.add_argument("--snr-threshold", type=float,
                    help="SNR threshold for trigger generation")
parser.add_argument("--newsnr-threshold", metavar='THRESHOLD', type=float,
                    help="Cut triggers with NewSNR less than THRESHOLD")
parser.add_argument("--low-frequency-cutoff", type=float,
                    help="The low frequency cutoff to use for filtering (Hz)")

# Template waveform options.
parser.add_argument("--approximant", type=str,
                    help="The name of the approximant to use for filtering")
parser.add_argument("--order", type=str,
                    help="The integer half-PN order at which to generate "
                         "the approximant.")
taper_choices = ["start", "end", "startend"]
parser.add_argument("--taper-template", choices=taper_choices,
                    help="For time-domain approximants, taper the start "
                         "and/or end of the waveform before FFTing.")

# Trigger clustering options.
parser.add_argument("--cluster-method", choices=["template", "window"])
parser.add_argument("--cluster-window", default=-1, type=float,
                    help="Length of clustering window in seconds")

# Signal-based veto options.
parser.add_argument("--bank-veto-bank-file", type=str)

# No `type` is given for --chisq-bins: a value supplied on the command line
# stays a string, while the default is the integer 0.
parser.add_argument("--chisq-bins", default=0)
parser.add_argument("--chisq-threshold", default=0, type=float)
parser.add_argument("--chisq-delta", default=0, type=float)

parser.add_argument("--autochi-number-points", default=0, type=int)
parser.add_argument("--autochi-stride", default=0, type=int)
parser.add_argument("--autochi-onesided", default=False, action='store_true')

# Sparse SNR sampling options.
parser.add_argument("--downsample-factor", default=1, type=int,
                    help="Factor that determines the interval between the "
                         "initial SNR sampling. If not set (or 1) no sparse "
                         "sample is created, and the standard full SNR is "
                         "calculated.")
parser.add_argument("--upsample-threshold", type=float,
                    help="The fraction of the SNR threshold to check the "
                         "sparse SNR sample.")
parser.add_argument("--upsample-method", default='pruned_fft',
                    choices=["pruned_fft"],
                    help="The method to find the SNR points between the "
                         "sparse SNR sample.")

# The help text below keeps the embedded newlines, indentation and trailing
# blanks of the original multi-line literal.
parser.add_argument(
    "--user-tag", metavar="TAG", type=str,
    help="\n"
         "                    This is used to identify FULL_DATA jobs for \n"
         "                    compatibility with pipedown post-processing. \n"
         "                    Option will be removed when no longer needed.")


# Attach the option groups supplied by the library modules to the parser.
option_group_inserters = (
    strain.insert_strain_option_group_multi_ifo,
    strain.StrainSegments.insert_segment_option_group_multi_ifo,
    psd.insert_psd_option_group_multi_ifo,
    scheme.insert_processing_option_group,
    fft.insert_fft_option_group,
)
for insert_group in option_group_inserters:
    insert_group(parser)

opt = parser.parse_args()

# Hand the parsed options back to each module's verify helper. The first
# three also receive the list of instruments; the last two do not.
multi_ifo_verifiers = (
    strain.verify_strain_options_multi_ifo,
    strain.StrainSegments.verify_segment_options_multi_ifo,
    psd.verify_psd_options_multi_ifo,
)
for verify_group in multi_ifo_verifiers:
    verify_group(opt, parser, opt.instruments)
scheme.verify_processing_options(opt, parser)
fft.verify_fft_options(opt, parser)

# --verbose switches logging to DEBUG; otherwise only WARN and above is shown.
log_level = logging.DEBUG if opt.verbose else logging.WARN
logging.basicConfig(level=log_level, format='%(asctime)s : %(message)s')

# Processing context built from the command-line options; the analysis
# further down runs inside it.
ctx = scheme.from_cli(opt)

# Per-instrument strain objects and their segment descriptions; both are
# indexed by ifo name further down.
strain_dict = strain.from_cli_multi_ifos(
    opt, opt.instruments, dyn_range_fac=DYN_RANGE_FAC)
strain_segments_dict = strain.StrainSegments.from_cli_multi_ifos(
    opt, strain_dict, opt.instruments)

# Main analysis: filter every template against every data segment of every
# instrument and collect the resulting triggers in `event_mgr`. Everything in
# this block runs inside the processing context `ctx`.
with ctx:
    fft.from_cli(opt)

    # Set some often used variables for easy access. The first instrument
    # supplies the reference values; each defaultdict returns that same value
    # for whatever key it is asked for.
    flow = opt.low_frequency_cutoff
    flow_dict = defaultdict(lambda : flow)

    ref_ifo = opt.instruments[0]
    sample_rate = strain_dict[ref_ifo].sample_rate
    sample_rate_dict = defaultdict(lambda : sample_rate)
    flen = strain_segments_dict[ref_ifo].freq_len
    flen_dict = defaultdict(lambda : flen)
    tlen = strain_segments_dict[ref_ifo].time_len
    tlen_dict = defaultdict(lambda : tlen)
    delta_f = strain_segments_dict[ref_ifo].delta_f
    delta_f_dict = defaultdict(lambda : delta_f)

    # Every other instrument must agree with the reference values. This is an
    # explicit comparison rather than `assert` so that the check still runs
    # under `python -O`, and so that unrelated errors (for example a missing
    # ifo key) surface as themselves instead of being reported as a mismatch.
    for ifo in opt.instruments[1:]:
        consistent = (
            sample_rate == strain_dict[ifo].sample_rate
            and flen == strain_segments_dict[ifo].freq_len
            and tlen == strain_segments_dict[ifo].time_len
            and delta_f == strain_segments_dict[ifo].delta_f)
        if not consistent:
            err_msg = "Sample rate, frequency length and time length "
            err_msg += "must all be consistent across ifos."
            raise ValueError(err_msg)

    logging.info("Making frequency-domain data segments")
    segments = {}
    for ifo in opt.instruments:
        segments[ifo] = strain_segments_dict[ifo].fourier_segments()
        # The segment description is not used again once its
        # frequency-domain segments exist, so drop the reference.
        del strain_segments_dict[ifo]

    logging.info("Computing noise PSD")
    # Named psd_dict so that the imported `psd` module is not rebound.
    psd_dict = psd.from_cli_multi_ifos(opt, flen_dict, delta_f_dict,
                                       flow_dict, opt.instruments,
                                       strain_dict=strain_dict,
                                       dyn_range_factor=DYN_RANGE_FAC)
    for ifo in opt.instruments:
        psd_dict[ifo] = psd_dict[ifo].astype(float32)
        # The strain is not referenced again after the PSD has been made.
        del strain_dict[ifo]

    # Currently we are using the same matched-filter parameters for all ifos.
    # Therefore only one MatchedFilterControl is needed. Maybe this can change
    # if needed. tlen and delta_f were checked above to be the same for all
    # ifos.
    matched_filter = MatchedFilterControl(opt.low_frequency_cutoff, None,
                                   opt.snr_threshold, tlen, delta_f, complex64,
                                   downsample_factor=opt.downsample_factor,
                                   upsample_threshold=opt.upsample_threshold,
                                   upsample_method=opt.upsample_method)

    logging.info("Initializing signal-based vetoes.")
    # The existing SingleDetPowerChisq can calculate the single detector
    # chisq for multiple ifos, so just use that directly.
    power_chisq = vetoes.SingleDetPowerChisq(opt.chisq_bins)
    # The existing SingleDetBankVeto can calculate the single detector
    # bank veto for multiple ifos, so we just use it directly.
    bank_chisq = vetoes.SingleDetBankVeto(opt.bank_veto_bank_file,
                                         opt.approximant, flen, delta_f, flow,
                                         complex64, phase_order=opt.order)
    # Same here
    autochisq = vetoes.SingleDetAutoChisq(opt.autochi_stride,
                                         opt.autochi_number_points,
                                         onesided=opt.autochi_onesided)

    logging.info("Overwhitening frequency-domain data segments")
    for ifo in opt.instruments:
        for seg in segments[ifo]:
            seg /= psd_dict[ifo]

    # Column types of the values stored for each trigger.
    out_types = {
        'time_index' : int,
        'ifo'        : int, # IFO is stored as an int internally!
        'snr'        : complex64,
        'chisq'      : float32,
        'chisq_dof'  : int,
        'bank_chisq' : float32,
        'cont_chisq' : float32
                }
    # Scratch storage for the column values of the segment being processed;
    # refilled for every (template, ifo, segment) that yields triggers.
    out_vals = {
        'time_index' : None,
        'ifo'        : None,
        'snr'        : None,
        'chisq'      : None,
        'chisq_dof'  : None,
        'bank_chisq' : None,
        'cont_chisq' : None
               }
    # A sorted name list fixes the column order used for both the types
    # given to the event manager and the values added to it later.
    names = sorted(out_vals)

    event_mgr = events.EventManagerMultiDet(opt, opt.instruments, names,
                                   [out_types[n] for n in names], psd=psd_dict)

    logging.info("Read in template bank")
    bank = waveform.FilterBank(opt.bank_file, opt.approximant,
                flen, delta_f, flow, complex64,
                phase_order=opt.order, taper=opt.taper_template,
                out=zeros(tlen, dtype=complex64))

    for t_num, template in enumerate(bank):
        sigmasq = {}
        for ifo in opt.instruments:
            sigmasq[ifo] = template.sigmasq(psd_dict[ifo])

        event_mgr.new_template(tmplt=template.params, sigmasq=sigmasq)

        # The clustering window (in samples) depends only on the template,
        # not on the ifo, so work it out once per template. Without a
        # --cluster-method there is no window to use; say so explicitly
        # instead of failing on an unbound name further down.
        if opt.cluster_method == "window":
            cluster_window = int(opt.cluster_window * sample_rate)
        elif opt.cluster_method == "template":
            cluster_window = int(template.chirp_length * sample_rate)
        else:
            raise ValueError("--cluster-method must be 'window' or "
                             "'template' to define the clustering window")

        for ifo in opt.instruments:
            for s_num, stilde in enumerate(segments[ifo]):
                # Arguments are passed separately so that the message is
                # only formatted when INFO logging is enabled.
                logging.info("Filtering template %d/%d ifo %s segment %d/%d",
                             t_num + 1, len(bank), ifo, s_num + 1,
                             len(segments[ifo]))

                snr, norm, corr, idx, snrv = \
                        matched_filter.matched_filter_and_cluster(template,
                            template.sigmasq(psd_dict[ifo]), stilde,
                            cluster_window)

                # Nothing came back for this segment. idx is offset with
                # `+` / `+=` below, so test its length, not its truthiness.
                if len(idx) == 0:
                    continue

                # The vetoes are handed the indices offset by the start of
                # the analysed part of the segment.
                out_vals['bank_chisq'] = bank_chisq.values(template,
                                  psd_dict[ifo], stilde, snrv, norm,
                                  idx+stilde.analyze.start)
                out_vals['chisq'], out_vals['chisq_dof'] =\
                                  power_chisq.values(corr, snr, snrv, norm,
                                                     psd_dict[ifo],
                                                     idx+stilde.analyze.start,
                                                     template)

                # Apply the normalisation in place. This has to stay after
                # the two veto calls above, which receive snrv and norm
                # separately.
                snrv *= norm

                out_vals['cont_chisq'] = autochisq.values(snr, corr,
                                  idx+stilde.analyze.start, template,
                                  psd_dict[ifo], norm,
                                  low_frequency_cutoff=flow)

                # Shift the indices by the segment's cumulative_index before
                # storing them as the trigger time_index.
                idx += stilde.cumulative_index

                out_vals['time_index'] = idx
                out_vals['snr'] = snrv
                # IFO is stored as an int
                out_vals['ifo'] = event_mgr.ifo_dict[ifo]

                event_mgr.add_template_events_to_ifo(ifo, names,
                                                  [out_vals[n] for n in names])

            # Cluster this template's triggers for the ifo once all of its
            # segments have been filtered.
            event_mgr.cluster_template_events_single_ifo("time_index",
                                                    "snr", cluster_window, ifo)
        event_mgr.finalize_template_events()

# Report how many triggers were collected.
logging.info("Found %s triggers", len(event_mgr.events))

# The chisq cut is applied only when both a threshold and chisq bins are set.
if opt.chisq_bins and opt.chisq_threshold:
    logging.info("Removing triggers with poor chisq")
    event_mgr.chisq_threshold(opt.chisq_threshold, opt.chisq_bins,
                              opt.chisq_delta)
    logging.info("%d remaining triggers", len(event_mgr.events))

# The NewSNR cut likewise needs chisq bins as well as its own threshold.
if opt.chisq_bins and opt.newsnr_threshold:
    logging.info("Removing triggers with NewSNR below threshold")
    event_mgr.newsnr_threshold(opt.newsnr_threshold)
    logging.info("%d remaining triggers", len(event_mgr.events))

logging.info("Writing out triggers")
event_mgr.write_events(opt.output)

# Export FFTW wisdom only for the precisions whose output file was requested;
# fft.fftw is not touched otherwise.
float_wisdom_path = opt.fftw_output_float_wisdom_file
if float_wisdom_path:
    fft.fftw.export_single_wisdom_to_filename(float_wisdom_path)

double_wisdom_path = opt.fftw_output_double_wisdom_file
if double_wisdom_path:
    fft.fftw.export_double_wisdom_to_filename(double_wisdom_path)

logging.info("Finished")
back to top