https://github.com/gwastro/pycbc
Raw File
Tip revision: 30b3be87e488e62975030b15ceeaf6fe4153b3ee authored by Gareth S Cabourn Davies on 17 October 2023, 12:13:26 UTC
Prepare for release v2.4.0 (#4531)
Tip revision: 30b3be8
pycbc_live_single_trigger_fits
#!/usr/bin/python

# Copyright 2020 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

"""Fit a background model to single-detector triggers from PyCBC Live.

See https://arxiv.org/abs/2008.07494 for a description of the method."""

import os
import sys
import argparse
import logging
import numpy as np
import h5py
import pycbc
from pycbc.bin_utils import IrregularBins
from pycbc.events import cuts, trigger_fits as trstats
from pycbc.io import DictArray
from pycbc.events import ranking
from pycbc.events.coinc import cluster_over_time


def duration_bins_from_cli(args):
    """Create the duration bins from CLI options.
    """
    if args.duration_bin_edges:
        # direct bin specification
        return np.array(args.duration_bin_edges)
    # calculate bins from min/max and number
    min_dur = args.duration_bin_start
    max_dur = args.duration_bin_end
    if args.duration_from_bank:
        # read min/max duration directly from the bank itself
        with h5py.File(args.duration_from_bank, 'r') as bank_file:
            temp_durs = bank_file['template_duration'][:]
        min_dur, max_dur = min(temp_durs), max(temp_durs)
    if args.duration_bin_spacing == 'log':
        return np.logspace(
            np.log10(min_dur),
            np.log10(max_dur),
            args.num_duration_bins + 1
        )
    if args.duration_bin_spacing == 'linear':
        return np.linspace(
            min_dur,
            max_dur,
            args.num_duration_bins + 1
        )
    raise RuntimeError("Invalid duration bin specification")


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--ifos", nargs="+", required=True,
                    help="Which ifo are we fitting the triggers for? "
                         "Required")
parser.add_argument("--top-directory", metavar='PATH', required=True,
                    help="Directory containing trigger files, top directory, "
                         "will contain subdirectories for each day of data. "
                         "Required.")
parser.add_argument("--analysis-date", required=True,
                    help="Date of the analysis, format YYYY_MM_DD. Required")
parser.add_argument("--file-identifier", default="H1L1V1-Live",
                    help="String required in filename to be considered for "
                         "analysis. Default: 'H1L1V1-Live'.")
parser.add_argument("--fit-function", default="exponential",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit. "
                         "Choose from exponential, rayleigh or power. "
                         "Default: exponential")
parser.add_argument("--duration-bin-edges", nargs='+', type=float,
                    help="Durations to use for bin edges. "
                         "Use if specifying exact bin edges, "
                         "Not compatible with --duration-bin-start, "
                         "--duration-bin-end and --num-duration-bins")
parser.add_argument("--duration-bin-start", type=float,
                    help="Shortest duration to use for duration bins."
                         "Not compatible with --duration-bins, requires "
                         "--duration-bin-end and --num-duration-bins.")
parser.add_argument("--duration-bin-end", type=float,
                    help="Longest duration to use for duration bins.")
parser.add_argument("--duration-from-bank",
                    help="Path to the template bank file to get max/min "
                         "durations from.")
parser.add_argument("--num-duration-bins", type=int,
                    help="How many template duration bins to split the bank "
                         "into before fitting.")
parser.add_argument("--duration-bin-spacing",
                    choices=['linear','log'], default='log',
                    help="How to set spacing for bank split "
                         "if using --num-duration-bins and "
                         "--duration-bin-start + --duration-bin-end "
                         "or --duration-from-bank.")
parser.add_argument('--prune-loudest', type=int,
                    help="Maximum number of loudest trigger clusters to "
                         "remove from each bin.")
parser.add_argument("--prune-window", type=float,
                    help="Window (seconds) either side of the --prune-loudest "
                         "loudest triggers in each duration bin to remove.")
parser.add_argument("--prune-stat-threshold", type=float,
                    help="Minimum --sngl-ranking value to consider a "
                         "trigger for pruning.")
parser.add_argument("--fit-threshold", type=float, default=5,
                    help="Lower threshold used in fitting the triggers."
                         "Default 5.")
parser.add_argument("--cluster", action='store_true',
                    help="Only use maximum of the --sngl-ranking value "
                         "from each file.")
parser.add_argument("--output", required=True,
                    help="File in which to save the output trigger fit "
                         "parameters.")
parser.add_argument("--sngl-ranking", default="newsnr",
                    choices=ranking.sngls_ranking_function_dict.keys(),
                    help="The single-detector trigger ranking to use.")

cuts.insert_cuts_option_group(parser)

args = parser.parse_args()

# Check input options

# Pruning options are mutually required or not needed
prune_options = [args.prune_loudest, args.prune_window,
                 args.prune_stat_threshold]

if any(prune_options) and not all(prune_options):
    parser.error("Require all or none of --prune-loudest, "
                 "--prune-window and --prune-stat-threshold")

# Check the bin options
if args.duration_bin_edges:
    if (args.duration_bin_start or args.duration_bin_end or
        args.duration_from_bank or args.num_duration_bins):
        parser.error("Cannot use --duration-bin-edges with "
                     "--duration-bin-start, --duration-bin-end, "
                     "--duration-from-bank or --num-duration-bins.")
else:
    if not args.num_duration_bins:
        parser.error("--num-duration-bins must be set if not using "
                     "--duration-bin-edges.")
    if not ((args.duration_bin_start and args.duration_bin_end) or
            args.duration_from_bank):
        parser.error("--duration-bin-start & --duration-bin-end or "
                     "--duration-from-bank must be set if not using "
                     "--duration-bin-edges.")
if args.duration_bin_end and \
        args.duration_bin_end <= args.duration_bin_start:
    parser.error("--duration-bin-end must be greater than "
                 "--duration-bin-start, got "
                 f"{args.duration_bin_end} and {args.duration_bin_start}")

pycbc.init_logging(args.verbose)

duration_bin_edges = duration_bins_from_cli(args)
logging.info("Duration bin edges: %s", duration_bin_edges)

logging.info("Finding files")

files = [f for f in os.listdir(os.path.join(args.top_directory, args.analysis_date))
         if args.file_identifier in f]

logging.info("%s files found", len(files))

# Add template duration cuts according to the bin inputs
args.template_cuts = args.template_cuts or []
args.template_cuts.append(f"template_duration:{min(duration_bin_edges)}:lower")
args.template_cuts.append(f"template_duration:{max(duration_bin_edges)}:upper_inc")

# Efficiency saving: add SNR cut before any others as sngl_ranking can
# only be less than SNR.
args.trigger_cuts = args.trigger_cuts or []
args.trigger_cuts.insert(0, f"snr:{args.fit_threshold}:lower_inc")

# Cut triggers with sngl-ranking below threshold
args.trigger_cuts.append(f"{args.sngl_ranking}:{args.fit_threshold}:lower_inc")

logging.info("Setting up the cut dictionaries")
trigger_cut_dict, template_cut_dict = cuts.ingest_cuts_option_group(args)

logging.info("Setting up duration bins")
tbins = IrregularBins(duration_bin_edges)

# Also calculate live time so that this fitting can be used in rate estimation
# Live time is not immediately obvious - get an approximation with 8 second
# granularity by adding 8 seconds per 'valid' file

live_time = {ifo: 0 for ifo in args.ifos}

logging.info("Getting events which meet criteria")

# Loop through files - add events which meet the immediately gettable
# criteria
date_directory = os.path.join(args.top_directory, args.analysis_date)

if not os.path.exists(date_directory):
    raise FileNotFoundError(f"The directory {date_directory} does not exist.")

files = [f for f in os.listdir(date_directory)
         if args.file_identifier in f and f.endswith('hdf')]

events = {}

for counter, filename in enumerate(files):
    if counter and counter % 1000 == 0:
        logging.info("Processed %d/%d files", counter, len(files))
        for ifo in args.ifos:
            if ifo not in events:
                # In case of no triggers for an extended period
                logging.info("%s: No data", ifo)
            else:
                logging.info("%s: %d triggers in %.0f s", ifo,
                             events[ifo].data['snr'].size, live_time[ifo])

    f = os.path.join(date_directory, filename)
    # If there is an IOerror with the file, don't fail, just carry on
    try:
        h5py.File(f, 'r')
    except IOError:
        logging.warn('IOError with file ' + f)
        continue

    # Triggers for this file
    triggers = {}
    with h5py.File(f, 'r') as fin:
        # Open the file: does it have the ifo group and snr dataset?
        for ifo in args.ifos:
            if not (ifo in fin and 'snr' in fin[ifo]):
                continue

            # Eventual FIX ME: live output files should (soon) have the live time
            # added, but for now, extract from the filename
            # Format of the filename is to have the live time as a dash,
            # followed by '.hdf' at the end of the filename
            lt = int(f.split('-')[-1][:-4])
            live_time[ifo] += lt

            n_triggers = fin[ifo]['snr'].size
            # Skip if there are no triggers
            if not n_triggers:
                continue

            # Read trigger value datasets from file
            # Get all datasets with the same size as the trigger SNRs, 
            # except for edge cases where the number of loudest, gates etc.
            # happens to be the same as the trigger count
            triggers[ifo] = {k: fin[ifo][k][:] for k in fin[ifo].keys()
                             if k not in ('loudest', 'stat', 'gates', 'psd')
                             and fin[ifo][k].size == n_triggers}

            # The stored chisq is actually reduced chisq, so hack the
            # chisq_dof dataset to use the standard conversions.
            # chisq_dof of 1.5 gives the right number (2 * 1.5 - 2 = 1)
            triggers[ifo]['chisq_dof'] = \
                1.5 * np.ones_like(triggers[ifo]['snr'])


    for ifo, trigs_ifo in triggers.items():

        # Apply the cuts to triggers
        keep_idx = cuts.apply_trigger_cuts(trigs_ifo, trigger_cut_dict)

        # triggers contains the datasets that we want to use for
        # the template cuts, so here it can be used as the template bank
        keep_idx = cuts.apply_template_cuts(trigs_ifo, template_cut_dict,
                                            template_ids=keep_idx)

        # Skip if no triggers survive the cuts
        if not keep_idx.size:
            continue

        # Apply the cuts
        triggers_cut = {k: trigs_ifo[k][keep_idx]
                        for k in trigs_ifo.keys()}

        # Calculate the sngl_ranking values
        sngls_value = ranking.get_sngls_ranking_from_trigs(
                          triggers_cut, args.sngl_ranking)

        triggers_cut[args.sngl_ranking] = sngls_value

        triggers_da = DictArray(data=triggers_cut)

        # If we are clustering, take the max sngl_ranking value
        if args.cluster:
            max_idx = sngls_value.argmax()
            # Make sure that the DictArray has array data, not float
            triggers_da = triggers_da.select([max_idx])

        if ifo in events:  # DictArray already exists for the ifo
            events[ifo] += triggers_da
        else:  # Set up a new dictionary entry
            events[ifo] = triggers_da

logging.info("All events processed")

logging.info("Number of events which meet all criteria:")
for ifo in args.ifos:
    if ifo not in events:
        logging.info("%s: No data", ifo)
    else:
        logging.info("%s: %d in %.2fs",
                     ifo, len(events[ifo]), live_time[ifo])

logging.info('Sorting events into template duration bins')

# Set up bins and prune loud events in each bin
n_bins = duration_bin_edges.size - 1
alphas = {i: np.zeros(n_bins, dtype=np.float32) for i in args.ifos}
counts = {i: np.zeros(n_bins, dtype=np.float32) for i in args.ifos}
event_bins = {}
times_to_prune = {ifo: [] for ifo in args.ifos}

for ifo in events:
    # Sort the events into their bins
    event_bins[ifo] = np.array([tbins[d]
                               for d in events[ifo].data['template_duration']])

    if args.prune_loudest:
        for bin_num in range(n_bins):
            inbin = event_bins[ifo] == bin_num

            binned_events = events[ifo].data[args.sngl_ranking][inbin]
            binned_event_times = events[ifo].data['end_time'][inbin]

            # Cluster triggers in time with the pruning window to ensure
            # that clusters are independent
            cidx = cluster_over_time(binned_events, binned_event_times,
                                     args.prune_window)

            # Find clusters at/above the statistic threshold
            above_stat_min = binned_events[cidx] >= args.prune_stat_threshold
            cidx = cidx[above_stat_min]

            if args.prune_loudest > cidx.size:
                # There are fewer clusters than the number specified,
                # so prune them all
                times_to_prune[ifo] += list(binned_event_times[cidx])
                continue

            # Find the loudest of the triggers in this bin
            argloudest = np.argsort(binned_events[cidx])[-args.prune_loudest:]
            times_to_prune[ifo] += list(binned_event_times[cidx][argloudest])

n_pruned = {ifo: [] for ifo in args.ifos}
pruned_trigger_times = {}
if args.prune_loudest:
    logging.info("Pruning triggers %.2fs either side of the loudest %d "
                 "triggers in each bin if %s > %.2f", args.prune_window,
                 args.prune_loudest, args.sngl_ranking,
                 args.prune_stat_threshold)
    for ifo in events:
        times = events[ifo].data['end_time'][:]
        outwith_window = np.ones_like(times, dtype=bool)
        for t in times_to_prune[ifo]:
            outwith_window &= abs(times - t) > args.prune_window
            # Need to make an (ever-so-small) correction to the live time
            live_time[ifo] -= 2 * args.prune_window

        # Save the pruned events for reporting
        within_window = np.logical_not(outwith_window)
        pruned_trigger_bins = event_bins[ifo][within_window]
        pruned_trigger_times[ifo] = times[within_window]

        # Remove pruned events from the arrays we will fit
        events[ifo] = events[ifo].select(outwith_window)
        event_bins[ifo] = event_bins[ifo][outwith_window]

        # Report the number of pruned triggers in each bin
        for bin_num in range(n_bins):
            pruned_inbin = pruned_trigger_bins == bin_num
            n_pruned_thisbin = np.count_nonzero(pruned_inbin)
            n_pruned[ifo].append(n_pruned_thisbin)
            logging.info("Pruned %d triggers from %s bin %d",
                         n_pruned_thisbin, ifo, bin_num)

# Do the fitting for each bin
for ifo in events:
    for bin_num in range(n_bins):

        inbin = event_bins[ifo] == bin_num

        if not np.count_nonzero(inbin):
            # No triggers, alpha and count are -1
            counts[ifo][bin_num] = -1
            alphas[ifo][bin_num] = -1
            continue

        counts[ifo][bin_num] = np.count_nonzero(inbin)
        alphas[ifo][bin_num], _ = trstats.fit_above_thresh(
            args.fit_function,
            events[ifo].data[args.sngl_ranking][inbin],
            args.fit_threshold
        )

logging.info("Writing results")
with h5py.File(args.output, 'w') as fout:
    for ifo in events:
        fout_ifo = fout.create_group(ifo)
        # Save the triggers we have used for the fits
        fout_ifo_trigs = fout_ifo.create_group('triggers')
        for key in events[ifo].data:
            fout_ifo_trigs[key] = events[ifo].data[key]
        if ifo in pruned_trigger_times:
            fout_ifo['pruned_trigger_times'] = pruned_trigger_times[ifo]

        fout_ifo['fit_coeff'] = alphas[ifo]
        fout_ifo['counts'] = counts[ifo]
        fout_ifo.attrs['live_time'] = live_time[ifo]
        fout_ifo.attrs['pruned_times'] = times_to_prune[ifo]
        fout_ifo.attrs['n_pruned'] = n_pruned[ifo]

    fout['bins_upper'] = tbins.upper()
    fout['bins_lower'] = tbins.lower()

    fout.attrs['analysis_date'] = args.analysis_date
    fout.attrs['input'] = sys.argv
    fout.attrs['cuts'] = args.template_cuts + args.trigger_cuts
    fout.attrs['fit_function'] = args.fit_function
    fout.attrs['fit_threshold'] = args.fit_threshold
    fout.attrs['sngl_ranking'] = args.sngl_ranking

logging.info("Done")
back to top