swh:1:snp:3a699297f000109a1bc833f294a54171df990207
Raw File
Tip revision: eba2e1f492dc26e1d409e36f966ebac8a5008ad4 authored by Gareth S Cabourn Davies on 07 September 2021, 15:54:14 UTC
error in that page_foreground was outputting integer times (#3790)
Tip revision: eba2e1f
pycbc_inspiral
#!/usr/bin/env python

# Copyright (C) 2014 Alex Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import sys, os
import logging, argparse, numpy, itertools
from six.moves import range
import pycbc
import pycbc.version
from pycbc import vetoes, psd, waveform, strain, scheme, fft, DYN_RANGE_FAC, events
from pycbc.vetoes.sgchisq import SingleDetSGChisq
from pycbc.filter import MatchedFilterControl, make_frequency_series, qtransform
from pycbc.types import TimeSeries, FrequencySeries, zeros, float32, complex64
import pycbc.version
import pycbc.opt
import pycbc.inject
import time

# Progress value most recently written by update_progress(); starts below any
# real progress value so that the first call always writes.
last_progress_update = -1.0

def update_progress(p,u,n):
    """Write the current (filtering) progress to a file, rate-limited.

    The file is only rewritten when the progress has grown by more than `u`
    since the last value that was written, so frequent calls do not cause
    one file write each.

    Parameters
    ----------
    p : float
        Current progress value; written with four decimal places.
    u : float
        Minimum increase over the last written value required before the
        file is rewritten.
    n : str
        Path of the progress file. It is overwritten on every update.
    """
    global last_progress_update
    if p <= last_progress_update + u:
        return
    # The context manager guarantees the handle is closed even if the
    # write fails part-way through.
    with open(n, "w") as f:
        f.write("%.4f" % p)
    last_progress_update = p

# Wall-clock start of the job; used later for setup/run-time accounting and
# for the --checkpoint-exit-maxtime test.
tstart = time.time()

# ---------------------------------------------------------------------------
# Command line definition. Options specific to this executable are declared
# first; the shared option groups (PSD, strain, segments, processing scheme,
# FFT, optimization, injection filter rejector, sine-Gaussian chisq) are
# added afterwards, just before parsing.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(usage='',
    description="Find single detector gravitational-wave triggers.")

parser.add_argument('--version', action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                  help="print extra debugging information", default=False )
parser.add_argument("--update-progress",
                  help="updates a file 'progress.txt' with a value 0 .. 1.0 when this amount of (filtering) progress was made",
                  type=float, default=0)
parser.add_argument("--update-progress-file",
                  help="name of the file to write the amount of (filtering) progress to",
                  type=str, default="progress.txt")
parser.add_argument("--output", type=str, help="FIXME: ADD")
parser.add_argument("--bank-file", type=str, help="FIXME: ADD")
parser.add_argument("--snr-threshold",
                  help="SNR threshold for trigger generation", type=float)
parser.add_argument("--newsnr-threshold", type=float, metavar='THRESHOLD',
                    help="Cut triggers with NewSNR less than THRESHOLD")
parser.add_argument("--low-frequency-cutoff", type=float,
                  help="The low frequency cutoff to use for filtering (Hz)")
parser.add_argument("--enable-bank-start-frequency", action='store_true',
                  help="Read the starting frequency of template waveforms"
                       " from the template bank.")
parser.add_argument("--max-template-length", type=float,
                  help="The maximum length of a template in seconds. The "
                       "starting frequency of the template is modified to "
                       "ensure the proper length")
parser.add_argument("--enable-q-transform", action='store_true',
                  help="compute the q-transform for each segment of a "
                       "given analysis run. (default = False)")
# add approximant arg
pycbc.waveform.bank.add_approximant_arg(parser)
parser.add_argument("--order", type=int,
                  help="The integer half-PN order at which to generate"
                       " the approximant. Default is -1 which indicates to use"
                       " approximant defined default.", default=-1,
                       choices = numpy.arange(-1, 9, 1))
taper_choices = ["start","end","startend"]
parser.add_argument("--taper-template", choices=taper_choices,
                    help="For time-domain approximants, taper the start and/or"
                    " end of the waveform before FFTing.")

# Trigger clustering options
parser.add_argument("--cluster-method", choices=["template", "window"],
                    help="FIXME: ADD")
parser.add_argument("--cluster-function", choices=["findchirp", "symmetric"],
                    help="How to cluster together triggers within a window. "
                    "'findchirp' uses a forward sliding window; 'symmetric' "
                    "will compare each window to the one before and after, keeping "
                    "only a local maximum.", default="findchirp")
parser.add_argument("--cluster-window", type=float, default = -1,
                    help="Length of clustering window in seconds."
                    " Set to 0 to disable clustering.")

# Signal-consistency (chisq) test options
parser.add_argument("--bank-veto-bank-file", type=str, help="FIXME: ADD")
parser.add_argument("--chisq-snr-threshold", type=float,
                    help="Minimum SNR to calculate the power chisq")
parser.add_argument("--chisq-bins", default=0, help=
                    "Number of frequency bins to use for power chisq. Specify"
                    " an integer for a constant number of bins, or a function "
                    "of template attributes.  Math functions are "
                    "allowed, ex. "
                    "'10./math.sqrt((params.mass1+params.mass2)/100.)'. "
                    "Non-integer values will be rounded down.")
parser.add_argument("--chisq-threshold", type=float, default=0,
                    help="FIXME: ADD")
parser.add_argument("--chisq-delta", type=float, default=0, help="FIXME: ADD")
# NOTE: the help strings below are built from adjacent string literals, so
# each fragment must carry its own separating space.
parser.add_argument("--autochi-number-points", type=int, default=0,
                    help="The number of points to use, in both directions if "
                         "doing a two-sided auto-chisq, to calculate the "
                         "auto-chisq statistic.")
parser.add_argument("--autochi-stride", type=int, default=0,
                    help="The gap, in sample points, between the points at "
                         "which to calculate auto-chisq.")
parser.add_argument("--autochi-two-phase", action="store_true",
                    default=False,
                    help="If given auto-chisq will be calculated by testing "
                         "against both phases of the SNR time-series. "
                         "If not given, only the phase matching the trigger "
                         "will be used.")
parser.add_argument("--autochi-onesided", action='store', default=None,
                    choices=['left','right'],
                    help="Decide whether to calculate auto-chisq using "
                         "points on both sides of the trigger or only on one "
                         "side. If not given points on both sides will be "
                         "used. If given, with either 'left' or 'right', "
                         "only points on that side (right = forward in time, "
                         "left = back in time) will be used.")
parser.add_argument("--autochi-reverse-template", action="store_true",
                    default=False,
                    help="If given, time-reverse the template before "
                         "calculating the auto-chisq statistic. This will "
                         "come at additional computational cost as the SNR "
                         "time-series will need recomputing for the time-"
                         "reversed template.")
parser.add_argument("--autochi-max-valued", action="store_true",
                    default=False,
                    help="If given, store only the maximum value of the auto-"
                         "chisq over all points tested. A disadvantage of this "
                         "is that the mean value will not be known "
                         "analytically.")
parser.add_argument("--autochi-max-valued-dof", action="store", metavar="INT",
                    type=int,
                    help="If using --autochi-max-valued this value denotes "
                         "the pre-calculated mean value that will be stored "
                         "as the auto-chisq degrees-of-freedom value.")

# Sparse SNR sampling options
parser.add_argument("--downsample-factor", type=int,
                    help="Factor that determines the interval between the "
                         "initial SNR sampling. If not set (or 1) no sparse sample "
                         "is created, and the standard full SNR is calculated.", default=1)
parser.add_argument("--upsample-threshold", type=float,
                    help="The fraction of the SNR threshold to check the sparse SNR sample.")
parser.add_argument("--upsample-method", choices=["pruned_fft"],
                    help="The method to find the SNR points between the sparse SNR sample.",
                    default='pruned_fft')
parser.add_argument("--user-tag", type=str, metavar="TAG", help="""
                    This is used to identify FULL_DATA jobs for
                    compatibility with pipedown post-processing.
                    Option will be removed when no longer needed.""")

# Options controlling which triggers are retained
parser.add_argument("--keep-loudest-log-chirp-window", type=float,
                    help="Keep loudest triggers within ln chirp mass window")
parser.add_argument("--keep-loudest-interval", type=float,
                    help="Window in seconds to maximize triggers over bank")
parser.add_argument("--keep-loudest-num", type=int,
                    help="Number of triggers to keep from each maximization interval")
parser.add_argument("--keep-loudest-stat", default="newsnr",
                    choices=events.stat.sngl_statistic_dict.keys(),
                    help="Statistic used to determine loudest to keep")
parser.add_argument("--finalize-events-template-rate", default=None,
                    type=int, metavar="NUM TEMPLATES",
                    help="After NUM TEMPLATES perform the various clustering "
                         "and rejection tests that would be performed at the "
                         "end of this job. Default is to only do those things "
                         "at the end of the job. This can help control memory "
                         "usage if a lot of triggers that would be rejected "
                         "are being retained. A suggested value for this is "
                         "500, but a good number may depend on other settings "
                         "and your specific use-case.")
parser.add_argument("--gpu-callback-method", default='none')
parser.add_argument("--use-compressed-waveforms", action="store_true", default=False,
                    help='Use compressed waveforms from the bank file.')
parser.add_argument("--waveform-decompression-method", action='store', default=None,
                    help='Method to be used decompress waveforms from the bank file.')

# Checkpointing options
parser.add_argument("--checkpoint-interval", type=int,
                    help="Save results to checkpoint file every X seconds. "
                         "Default is no checkpointing.")
parser.add_argument('--require-valid-checkpoint', default=False, action="store_true",
                    help="If the checkpoint file is invalid, raise an error. "
                         "Default is to ignore invalid checkpoint files and to "
                         "delete the broken file.")
parser.add_argument("--checkpoint-exit-maxtime", type=int,
                    help="Checkpoint and exit if X seconds of execution"
                         " time is exceeded. Default is no checkpointing.")
parser.add_argument("--checkpoint-exit-code", type=int, default=77,
                    help="Exit code returned if exiting after a checkpoint")

# Add options groups
psd.insert_psd_option_group(parser)
strain.insert_strain_option_group(parser)
strain.StrainSegments.insert_segment_option_group(parser)
scheme.insert_processing_option_group(parser)
fft.insert_fft_option_group(parser)
pycbc.opt.insert_optimization_option_group(parser)
pycbc.inject.insert_injfilterrejector_option_group(parser)
SingleDetSGChisq.insert_option_group(parser)
opt = parser.parse_args()

# Run each option group's checker against the parsed options, in the same
# order in which they were originally listed.
for check_options in (psd.verify_psd_options,
                      strain.verify_strain_options,
                      strain.StrainSegments.verify_segment_options,
                      scheme.verify_processing_options,
                      fft.verify_fft_options,
                      pycbc.opt.verify_optimization_options):
    check_options(opt, parser)

pycbc.init_logging(opt.verbose)

# Build the FFT configuration, injection filter rejector and processing
# scheme context from the command line.
fft.from_cli(opt)
inj_filter_rejector = pycbc.inject.InjFilterRejector.from_cli(opt)
ctx = scheme.from_cli(opt)

# Obtain the strain data, then the analysis segments derived from it.
gwstrain = strain.from_cli(
    opt,
    dyn_range_fac=DYN_RANGE_FAC,
    inj_filter_rejector=inj_filter_rejector,
)
strain_segments = strain.StrainSegments.from_cli(opt, gwstrain)


# Main body of the search, run inside the processing scheme context:
# prepare the frequency-domain segments and their PSDs, build (or restore
# from a checkpoint) the event manager, construct the matched filter and
# chisq tests, then filter every template against every segment.
with ctx:
    # The following FFTW specific options needed to wait until
    # we were inside the scheme context.

    # Import system wisdom.
    if opt.fftw_import_system_wisdom:
        fft.fftw.import_sys_wisdom()

    # Read specified user-provided wisdom files
    if opt.fftw_input_float_wisdom_file is not None:
        fft.fftw.import_single_wisdom_from_filename(opt.fftw_input_float_wisdom_file)

    if opt.fftw_input_double_wisdom_file is not None:
        fft.fftw.import_double_wisdom_from_filename(opt.fftw_input_double_wisdom_file)

    flow = opt.low_frequency_cutoff
    flen = strain_segments.freq_len
    tlen = strain_segments.time_len
    delta_f = strain_segments.delta_f


    logging.info("Making frequency-domain data segments")
    segments = strain_segments.fourier_segments()
    psd.associate_psds_to_segments(opt, segments, gwstrain, flen, delta_f,
                  flow, dyn_range_factor=DYN_RANGE_FAC, precision='single')

    # storage for values and types to be passed to event manager
    out_types = {
        'time_index'     : int,
        'snr'            : complex64,
        'chisq'          : float32,
        'chisq_dof'      : int,
        'bank_chisq'     : float32,
        'bank_chisq_dof' : int,
        'cont_chisq'     : float32,
        'psd_var_val'    : float32,
                }
    out_types.update(SingleDetSGChisq.returns)
    out_vals = {key: None for key in out_types}
    names = sorted(out_vals.keys())

    # With no segments to analyse, write an empty trigger file and stop.
    if len(strain_segments.segment_slices) == 0:
        logging.info("--filter-inj-only specified and no injections in analysis time")
        event_mgr = events.EventManager(
              opt, names, [out_types[n] for n in names], psd=None,
              gating_info=gwstrain.gating_info)
        event_mgr.finalize_template_events()
        event_mgr.write_events(opt.output)
        logging.info("Finished")
        sys.exit(0)

    # FIXME: Maybe we should use the PSD corresponding to each trigger
    if opt.psdvar_segment is not None:
        logging.info("Calculating PSD variation")
        psd_var = pycbc.psd.calc_filt_psd_variation(gwstrain, opt.psdvar_segment,
                opt.psdvar_short_segment, opt.psdvar_long_segment,
                opt.psdvar_psd_duration, opt.psdvar_psd_stride,
                opt.psd_estimation, opt.psdvar_low_freq, opt.psdvar_high_freq)

    if opt.enable_q_transform:
        logging.info("Performing q-transform on analysis segments")
        q_trans = qtransform.inspiral_qtransform_generator(segments)

    else:
        q_trans = {}

    if opt.checkpoint_interval:
        checkpoint_file = opt.output + '.checkpoint'

    # Try to resume from an existing checkpoint; tnum_start stays 0 and a
    # fresh event manager is created below when none can be restored.
    checkpoint_exists = False
    tnum_start = 0
    if (opt.checkpoint_interval and os.path.isfile(checkpoint_file)):
        try:
            tnum_start, event_mgr = events.EventManager.restore_state(checkpoint_file)
            checkpoint_exists = True
        except Exception as e:
            # The parsed options live in `opt`; this previously referenced
            # an undefined name `args`, which turned every failed restore
            # into a NameError.
            if opt.require_valid_checkpoint:
                logging.info("Failed to load checkpoint file")
                # Bare raise keeps the original traceback intact.
                raise

            logging.info('Failed to load checkpoint file, starting anew')
            logging.info(e)

    if not checkpoint_exists:
        event_mgr = events.EventManager(
            opt, names, [out_types[n] for n in names], psd=segments[0].psd,
            gating_info=gwstrain.gating_info, q_trans=q_trans)

    template_mem = zeros(tlen, dtype = complex64)
    # Clustering window converted from seconds to samples.
    cluster_window = int(opt.cluster_window * gwstrain.sample_rate)

    if opt.cluster_window == 0.0:
        use_cluster = False
    else:
        use_cluster = True

    if hasattr(ctx, "num_threads"):
            ncores = ctx.num_threads
    else:
            ncores = 1


    matched_filter = MatchedFilterControl(opt.low_frequency_cutoff, None,
                                   opt.snr_threshold, tlen, delta_f, complex64,
                                   segments, template_mem, use_cluster,
                                   downsample_factor=opt.downsample_factor,
                                   upsample_threshold=opt.upsample_threshold,
                                   upsample_method=opt.upsample_method,
                                   gpu_callback_method=opt.gpu_callback_method,
                                   cluster_function=opt.cluster_function)

    bank_chisq = vetoes.SingleDetBankVeto(opt.bank_veto_bank_file,
                                          flen, delta_f, flow, complex64,
                                          phase_order=opt.order,
                                          approximant=opt.approximant)

    power_chisq = vetoes.SingleDetPowerChisq(opt.chisq_bins, opt.chisq_snr_threshold)

    autochisq = vetoes.SingleDetAutoChisq(opt.autochi_stride,
                                 opt.autochi_number_points,
                                 onesided=opt.autochi_onesided,
                                 twophase=opt.autochi_two_phase,
                                 reverse_template=opt.autochi_reverse_template,
                                 take_maximum_value=opt.autochi_max_valued,
                                 maximal_value_dof=opt.autochi_max_valued_dof)

    logging.info("Overwhitening frequency-domain data segments")
    for seg in segments:
        seg /= seg.psd

    logging.info("Read in template bank")
    bank = waveform.FilterBank(opt.bank_file, flen, delta_f,
        low_frequency_cutoff=None if opt.enable_bank_start_frequency else flow,
        dtype=complex64, phase_order=opt.order,
        taper=opt.taper_template, approximant=opt.approximant,
        out=template_mem, max_template_length=opt.max_template_length,
        enable_compressed_waveforms=True if opt.use_compressed_waveforms else False,
        waveform_decompression_method=
        opt.waveform_decompression_method if opt.use_compressed_waveforms else None)

    sg_chisq = SingleDetSGChisq.from_cli(opt, bank, opt.chisq_bins)

    ntemplates = len(bank)
    nfilters = 0

    logging.info("Full template bank size: %s", ntemplates)
    bank.template_thinning(inj_filter_rejector)
    if not len(bank) == ntemplates:
        logging.info("Template bank size after thinning: %s", len(bank))

    tsetup = time.time() - tstart
    tcheckpoint = time.time()

    # Note: in the class-based approach used now, 'template' is not explicitly used
    # within the loop.  Rather, the iteration simply fills the memory specifed in
    # the 'template_mem' argument to MatchedFilterControl with the next template
    # from the bank.
    # The offset by tnum_start skips the templates before the index restored
    # from a checkpoint (tnum_start is 0 when starting fresh).
    for t_num in range(len(bank) - tnum_start):
        t_num += tnum_start
        tmplt_generated = False

        for s_num, stilde in enumerate(segments):
            # Filter check checks the 'inj_filter_rejector' options to
            # determine whether
            # to filter this template/segment if injections are present.
            if not inj_filter_rejector.template_segment_checker(
                    bank, t_num, stilde, opt.gps_start_time):
                continue
            # Generate the template lazily, only once a segment needs it.
            if not tmplt_generated:
                template = bank[t_num]
                event_mgr.new_template(tmplt=template.params,
                    sigmasq=template.sigmasq(segments[0].psd))
                tmplt_generated = True

            if opt.cluster_method == "window":
                cluster_window = int(opt.cluster_window * gwstrain.sample_rate)
            if opt.cluster_method == "template":
                cluster_window = \
                    int(template.chirp_length * gwstrain.sample_rate)

            if opt.update_progress:
                update_progress((t_num + (s_num / float(len(segments))) ) / len(bank),
                                opt.update_progress, opt.update_progress_file)
            logging.info("Filtering template %d/%d segment %d/%d" %
                         (t_num + 1, len(bank), s_num + 1, len(segments)))

            nfilters = nfilters + 1
            snr, norm, corr, idx, snrv = \
               matched_filter.matched_filter_and_cluster(s_num,
                                                         template.sigmasq(stilde.psd),
                                                         cluster_window,
                                                         epoch=stilde._epoch)

            # Nothing was returned for this segment; skip the chisq tests.
            if not len(idx):
                continue

            out_vals['bank_chisq'], out_vals['bank_chisq_dof'] = \
                  bank_chisq.values(template, stilde.psd, stilde, snrv, norm,
                                    idx+stilde.analyze.start)

            out_vals['chisq'], out_vals['chisq_dof'] = \
                  power_chisq.values(corr, snrv, norm, stilde.psd,
                                     idx+stilde.analyze.start, template)

            out_vals['sg_chisq'] = sg_chisq.values(stilde, template, stilde.psd,
                                          snrv, norm,
                                          out_vals['chisq'],
                                          out_vals['chisq_dof'],
                                          idx+stilde.analyze.start)

            out_vals['cont_chisq'] = \
                  autochisq.values(snr, idx+stilde.analyze.start, template,
                                   stilde.psd, norm, stilde=stilde,
                                   low_frequency_cutoff=flow)

            idx += stilde.cumulative_index

            out_vals['time_index'] = idx
            out_vals['snr'] = snrv * norm

            # NOTE(review): psd_var is only assigned above when
            # opt.psdvar_segment is set, whereas this guard tests
            # opt.psdvar_short_segment; presumably option verification
            # keeps the two in step -- confirm.
            if opt.psdvar_short_segment is not None:
                out_vals['psd_var_val'] = \
                            pycbc.psd.find_trigger_value(psd_var,
                                          out_vals['time_index'],
                                          opt.gps_start_time, opt.sample_rate)

            event_mgr.add_template_events(names, [out_vals[n] for n in names])

        event_mgr.cluster_template_events("time_index", "snr", cluster_window)
        event_mgr.finalize_template_events()
        if opt.finalize_events_template_rate is not None and \
                not (t_num+1) % opt.finalize_events_template_rate:
            event_mgr.consolidate_events(opt, gwstrain=gwstrain)

        # Periodic checkpoint once the requested interval has elapsed.
        if opt.checkpoint_interval and \
            (time.time() - tcheckpoint > opt.checkpoint_interval):
            event_mgr.save_state(t_num, opt.output + '.checkpoint')
            tcheckpoint = time.time()

        # Checkpoint and exit with the configured code once the maximum
        # execution time has been exceeded.
        if opt.checkpoint_exit_maxtime and \
            (time.time() - tstart > opt.checkpoint_exit_maxtime):
            event_mgr.save_state(t_num, opt.output + '.checkpoint')
            sys.exit(opt.checkpoint_exit_code)

# Final consolidation pass over the accumulated triggers.
event_mgr.consolidate_events(opt, gwstrain=gwstrain)
event_mgr.finalize_events()
logging.info("Outputting %s triggers", len(event_mgr.events))

# Store performance figures: total run time is measured from script start.
run_time = time.time() - tstart
event_mgr.save_performance(ncores, nfilters, ntemplates, run_time, tsetup)

logging.info("Writing out triggers")
event_mgr.write_events(opt.output)

# Export FFTW wisdom to whichever output files were requested.
float_wisdom_out = opt.fftw_output_float_wisdom_file
if float_wisdom_out:
    fft.fftw.export_single_wisdom_to_filename(float_wisdom_out)

double_wisdom_out = opt.fftw_output_double_wisdom_file
if double_wisdom_out:
    fft.fftw.export_double_wisdom_to_filename(double_wisdom_out)

logging.info("Finished")
back to top