Revision 2a41f2768746c60e33b0d71c8b1a12d2fe50095f authored by Tito Dal Canton on 17 May 2023, 14:39:05 UTC, committed by GitHub on 17 May 2023, 14:39:05 UTC
* Fix edge case in check for horizon distance

* Suggestion by Gareth; improve docstring & comments
1 parent 2be945c
Raw File
pycbc_pygrb_page_tables
#!/usr/bin/env python

# Copyright (C) 2021 Francesco Pannarale & Cameron Mills
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Processes PyGRB triggers and injections to create html results tables.
"""

# =============================================================================
# Preamble
# =============================================================================
import sys
import os
import logging
import numpy as np
from scipy import stats
import h5py
import pycbc.version
from pycbc.detector import Detector
from pycbc import init_logging
import pycbc.results
from pycbc.results import pygrb_postprocessing_utils as ppu
try:
    from glue.ligolw import lsctables
except ImportError:
    pass

__author__ = "Francesco Pannarale <francesco.pannarale@ligo.org>"
__version__ = pycbc.version.git_verbose_msg
__date__ = pycbc.version.date
__program__ = "pycbc_pygrb_page_tables"


# =============================================================================
# Functions
# =============================================================================
def get_column(table, column_name):
    """Return one column of a table, hiding a couple of naming quirks.

    Thin wrapper around the table's own ``get_column`` method, kept so
    that a later switch to h5py only needs changes in one place. Note:
    the .get_sigmasqs() and .get_snglsnr() methods used throughout the
    scripts are not covered by this wrapper.

    Parameters
    ----------
    table: glue.ligolw.lsctables.MultiInspiralTable
        Any object offering ``get_column(name)`` and ``get_null_snr()``.
    column_name: str
        Name of the requested column.

    Returns
    -------
    column: numpy.array
        Whatever the underlying table method returns.
    """

    # The null SNR is served by a dedicated method rather than a column
    if column_name == "null_snr":
        return table.get_null_snr()

    # Effective distance columns carry only the first letter of the ifo,
    # so a trailing detector digit ("1" or "2") has to be dropped
    is_eff_dist = column_name.startswith("eff_dist_")
    lookup_name = column_name
    if is_eff_dist and column_name.endswith(("1", "2")):
        lookup_name = column_name[:-1]

    return table.get_column(lookup_name)


def dict_to_ndarray(dict_to_convert):
    """Stack the values of a single level dict into one numpy.ndarray.

    The keys are discarded; the values are placed along the first axis
    in the iteration order of the dictionary.
    """

    stacked_values = [value for value in dict_to_convert.values()]
    return np.asarray(stacked_values)


def lsctable_to_dict(inj_table, ifos, opts, trig_table=None, background_bestnrs=None):
    """Extract parameters from injections in lsctables. If found extract
    trigger info.

    Parameters
    ----------
    inj_table: ligo.lw.lsctables.SimInspiralTable
        Table the injected parameters are read from.
    ifos: list
        Interferometer names (e.g. 'H1'). They are lower-cased to build
        the per-detector keys of the output dictionary.
    opts: object
        Parsed command line options. The attributes read here are
        chisq_index, chisq_nhigh, null_snr_threshold (a comma-separated
        string), snr_threshold, sngl_snr_threshold, newsnr_threshold,
        null_grad_thresh and null_grad_val.
    trig_table: glue.ligolw.lsctables.MultiInspiralTable, optional
        Triggers associated with the injections. When it is missing or
        evaluates to False only the injected parameters are returned.
    background_bestnrs: numpy.array, optional
        Used to compute FAP of quiet injections.

    Returns
    -------
    inj_data: dict of parameter numpy.array pairs
        Found or missed injection parameter dictionary. With a trigger
        table it also holds the recovered parameters ('rec_*'), the
        filter statistics, the single detector SNRs ('snr_<ifo>'),
        'bestnr', optionally 'fap', and one 'inj_sigma_mean_<ifo>' value
        per detector.
    """

    inj_data = {}

    # Local copies of variables entering the BestNR definition
    chisq_index = opts.chisq_index
    chisq_nhigh = opts.chisq_nhigh
    null_thresh = list(map(float, opts.null_snr_threshold.split(',')))
    snr_thresh = opts.snr_threshold
    sngl_snr_thresh = opts.sngl_snr_threshold
    new_snr_thresh = opts.newsnr_threshold
    null_grad_thresh = opts.null_grad_thresh
    null_grad_val = opts.null_grad_val

    # Get inj params
    inj_params = ['mchirp', 'mass1', 'mass2', 'distance', 'inclination',
                  'spin1x', 'spin1y', 'spin1z', 'spin2x', 'spin2y', 'spin2z']
    for param in inj_params:
        inj_data[param] = get_column(inj_table, param)
    inj_data['ra'] = get_column(inj_table, 'longitude')
    inj_data['dec'] = get_column(inj_table, 'latitude')
    # Combine the integer seconds and the nanoseconds into a single time
    inj_data['time'] = get_column(inj_table, 'geocent_end_time') \
                + 1e-9 * get_column(inj_table, 'geocent_end_time_ns')
    for ifo in ifos:
        site = ifo.lower()
        inj_data['eff_dist_{}'.format(site)] = get_column(
            inj_table, 'eff_dist_{}'.format(site))
    # Combined effective distance: inverse of the sum of the inverse
    # single-detector effective distances. Each detector must contribute
    # its own column, so the key is built from the loop variable rather
    # than from the `site` left over by the loop above.
    inj_data['eff_dist'] = np.asarray([
        1.0 / sum(1.0 / inj_data['eff_dist_{}'.format(ifo.lower())]
                  for ifo in ifos)])

    if trig_table:
        # Get recovered parameters and statistic values
        rec_params = ['mchirp', 'mass1', 'mass2', 'ra', 'dec']
        filter_stats = ['snr', 'chisq', 'bank_chisq', 'cont_chisq', 'null_snr']
        for param in rec_params:
            inj_data["rec_" + param] = get_column(trig_table, param)
        for stat in filter_stats:
            inj_data[stat] = get_column(trig_table, stat)
        for ifo in ifos:
            site = ifo.lower()
            inj_data['snr_{}'.format(site)] = trig_table.get_sngl_snr(site)
        inj_data['bestnr'] = ppu.get_bestnrs(trig_table,
                                             q=chisq_index,
                                             n=chisq_nhigh,
                                             null_thresh=null_thresh,
                                             snr_threshold=snr_thresh,
                                             sngl_snr_threshold=sngl_snr_thresh,
                                             chisq_threshold=new_snr_thresh,
                                             null_grad_thresh=null_grad_thresh,
                                             null_grad_val=null_grad_val)
        if background_bestnrs is not None:
            # FAP: fraction of background values strictly louder than the
            # BestNR of each injection
            inj_data['fap'] = np.array(
                [sum(background_bestnrs > bestnr) for bestnr in inj_data['bestnr']],
                dtype=float) / len(background_bestnrs)

        # Antenna responses
        f_resp = {}
        # Calibration errors:
        # Get the relative detector sensitivities averaged over the
        # parameters. This is used to marginalize over calibration errors.
        inj_sigma = trig_table.get_sigmasqs()
        # If the sigmasqs are not populated, we can only do calibration
        # errors in the 1-detector case.
        for ifo in ifos:
            if sum(inj_sigma[ifo] == 0):
                logging.info("%s: sigmasq not set for at least one trigger.", ifo)
            if sum(inj_sigma[ifo] != 0) == 0:
                logging.info("%s: sigmasq not set for any trigger.", ifo)
                if len(ifos) == 1:
                    msg = "This is a single ifo analysis. "
                    msg += "Setting sigmasq to unity for all triggers."
                    logging.info(msg)
                    inj_sigma[ifo][:] = 1.0
            antenna = Detector(ifo)
            f_resp[ifo] = ppu.get_antenna_responses(antenna, inj_data['ra'],
                                                    inj_data['dec'],
                                                    inj_data['time'])
        # NOTE(review): the element-wise product below pairs the two
        # dictionaries by iteration order, so it assumes get_sigmasqs()
        # yields its detectors in the same order as `ifos` -- confirm.
        inj_sigma_mult = dict_to_ndarray(inj_sigma) * dict_to_ndarray(f_resp)
        inj_sigma_tot = np.sum(inj_sigma_mult, axis=0)
        # Mean fractional contribution of each detector
        for ifo in ifos:
            inj_data['inj_sigma_mean_%s' % ifo.lower()] = np.mean(
                inj_sigma[ifo] * f_resp[ifo] / inj_sigma_tot)

    return inj_data


# =============================================================================
# Main script starts here
# =============================================================================
# Start from the common PyGRB plotting parser and add the options that are
# specific to this script
parser = ppu.pygrb_initialize_plot_parser(description=__doc__,
                                          version=__version__)

# Input trigger files
parser.add_argument("-F", "--offsource-file", required=True,
                    help="Location of off-source trigger file")
# As opposed to offsource-file and trig-file, this only contains onsource
parser.add_argument("--onsource-file",
                    help="Location of on-source trigger file.")

# Option groups shared with the other PyGRB post-processing scripts
ppu.pygrb_add_missed_injs_input_opt(parser)
ppu.pygrb_add_injmc_opts(parser)
ppu.pygrb_add_bestnr_opts(parser)

parser.add_argument("--num-loudest-off-trigs", type=int, default=30,
                    help="Number of loudest "
                         "offsouce triggers to output details about.")

# Output files: all of them are optional at the parser level, the check
# that at least one was requested is done after parsing
for output_opt, output_help in (
        ("--quiet-found-injs-output-file",
         "Quiet-found injections html output file."),
        ("--missed-found-injs-output-file",
         "Missed-found injections html output file."),
        ("--quiet-found-injs-h5-output-file",
         "Quiet-found injections h5 output file."),
        ("--loudest-offsource-trigs-output-file",
         "Loudest offsource triggers html output file."),
        ("--loudest-offsource-trigs-h5-output-file",
         "Loudest offsource triggers h5 output file."),
        ("--loudest-onsource-trig-output-file",
         "Loudest onsource trigger html output file."),
        ("--loudest-onsource-trig-h5-output-file",
         "Loudest onsource trigger h5 output file.")):
    parser.add_argument(output_opt, default=None, help=output_help)

# Tuning parameters of the exclusion efficiency calculation
parser.add_argument("-g", "--glitch-check-factor", type=float, default=1.0,
                    help="When deciding "
                         "exclusion efficiencies this value is multiplied "
                         "to the offsource around the injection trigger to "
                         "determine if it is just a loud glitch.")
parser.add_argument("-C", "--cluster-window", type=float, default=0.1,
                    help="The cluster window used "
                         "to cluster triggers in time.")
opts = parser.parse_args()

init_logging(opts.verbose, format="%(asctime)s: %(levelname)s: %(message)s")

# TODO: split this script into one that handles injections
# and one that handles loudest on/off-source events?
# TODO: deal better with the multiple output file possibilities?
# The hdf5 output files are handy for injection follow-up
# but for the moment this remains a hacky solution
output_files = [opts.quiet_found_injs_output_file,
                opts.missed_found_injs_output_file,
                opts.quiet_found_injs_h5_output_file,
                opts.loudest_offsource_trigs_output_file,
                opts.loudest_offsource_trigs_h5_output_file,
                opts.loudest_onsource_trig_output_file,
                opts.loudest_onsource_trig_h5_output_file]
# There is nothing to do unless at least one output was requested
if output_files.count(None) == len(output_files):
    msg = "Please specify at least one output file location."
    parser.error(msg)

# Injections are processed only if one of their outputs was requested
if opts.quiet_found_injs_output_file or opts.missed_found_injs_output_file or\
    opts.quiet_found_injs_h5_output_file:
    do_injections = True
    # Both files are loaded when processing injections, so it is an error
    # if either one of them is missing (not only if both are)
    if (opts.found_file is None) or (opts.missed_file is None):
        err_msg = "Must provide both found and missed injections file "
        err_msg += "locations if processing injections."
        parser.error(err_msg)
else:
    do_injections = False

# The loudest on-source trigger can only be reported if the on-source
# triggers were provided
if opts.loudest_onsource_trig_output_file or opts.loudest_onsource_trig_h5_output_file:
    if opts.onsource_file is None:
        err_msg = "Must provide the on-source file location to output its "
        err_msg += "loudest trigger information."
        parser.error(err_msg)

# Fall back on the SNR threshold when no (or a zero) new SNR threshold
# was given
if not opts.newsnr_threshold:
    opts.newsnr_threshold = opts.snr_threshold

# Store options used multiple times in local variables

# -- Input files
trig_file = opts.offsource_file
onsource_file = opts.onsource_file
found_file = opts.found_file
missed_file = opts.missed_file

# -- Output files (html tables and their h5 counterparts)
qf_outfile = opts.quiet_found_injs_output_file
mf_outfile = opts.missed_found_injs_output_file
lofft_outfile = opts.loudest_offsource_trigs_output_file
lont_outfile = opts.loudest_onsource_trig_output_file
qf_h5_outfile = opts.quiet_found_injs_h5_output_file
lofft_h5_outfile = opts.loudest_offsource_trigs_h5_output_file
lont_h5_outfile = opts.loudest_onsource_trig_h5_output_file

# -- Quantities passed to the BestNR calculation
chisq_index = opts.chisq_index
chisq_nhigh = opts.chisq_nhigh
snr_thresh = opts.snr_threshold
sngl_snr_thresh = opts.sngl_snr_threshold
new_snr_thresh = opts.newsnr_threshold
null_grad_thresh = opts.null_grad_thresh
null_grad_val = opts.null_grad_val
# The null SNR thresholds arrive as a comma-separated string
null_thresh = [float(value) for value in opts.null_snr_threshold.split(',')]

# -- Waveform and calibration errors, the latter keyed by detector
wf_err = opts.waveform_error
wav_err = opts.waveform_error
cal_errs = {ifo: getattr(opts, '%s_cal_error' % ifo.lower())
            for ifo in ('G1', 'H1', 'K1', 'L1', 'V1')}
cal_dc_errs = {ifo: getattr(opts, '%s_dc_cal_error' % ifo.lower())
               for ifo in ('G1', 'H1', 'K1', 'L1', 'V1')}

# -- Injection distance binning and Monte Carlo settings
upper_dist = opts.upper_inj_dist
lower_dist = opts.lower_inj_dist
num_bins = opts.num_bins
num_mc_injs = opts.num_mc_injections

# -- Settings of the check against nearby loud triggers
cluster_window = opts.cluster_window
glitch_check_fac = opts.glitch_check_factor

# Set output directories: create the parent directory of every requested
# output file that does not exist yet
logging.info("Setting output directory.")
for output_file in output_files:
    if not output_file:
        continue
    outdir = os.path.dirname(os.path.abspath(output_file))
    if not os.path.isdir(outdir):
        os.makedirs(outdir)

# Initialize random number generator
np.random.seed(opts.seed)
logging.info("Setting random seed to %d.", opts.seed)

# Extract IFOs and vetoes
ifos, vetoes = ppu.extract_ifos_and_vetoes(
    trig_file, opts.veto_files, opts.veto_category)

# Load triggers, time-slides, and segment dictionary
logging.info("Loading triggers.")
trigs = ppu.load_xml_table(trig_file, lsctables.MultiInspiralTable.tableName)
logging.info("%d triggers loaded.", len(trigs))

logging.info("Loading timeslides.")
slide_dict = ppu.load_time_slides(trig_file)

logging.info("Loading segments.")
segment_dict = ppu.load_segment_dict(trig_file)

# Construct trials
logging.info("Constructing trials.")
trial_dict = ppu.construct_trials(
    opts.seg_files, segment_dict, ifos, slide_dict, vetoes)
# Number of trials summed over all the time slides
total_trials = sum(len(trial_dict[slide_id]) for slide_id in slide_dict)
logging.info("%d trials generated.", total_trials)

# Extract basic trigger properties and store as dictionaries
# (indexed by slide_id further down)
trig_time, trig_snr, trig_bestnr = ppu.extract_basic_trig_properties(
    trial_dict, trigs, slide_dict, segment_dict, opts)

# Calculate SNR and BestNR values and maxima
# For every time slide, store the loudest SNR and the loudest BestNR found
# in each of its trials. Trials without triggers keep the initial zero.
time_veto_max_snr = {}
time_veto_max_bestnr = {}
for slide_id in slide_dict:
    num_slide_segs = len(trial_dict[slide_id])
    time_veto_max_snr[slide_id] = np.zeros(num_slide_segs)
    time_veto_max_bestnr[slide_id] = np.zeros(num_slide_segs)

    for j, trial in enumerate(trial_dict[slide_id]):
        # Triggers of this slide falling in the half-open trial interval
        trial_cut = (trial[0] <= trig_time[slide_id])\
                          & (trig_time[slide_id] < trial[1])
        if not trial_cut.any():
            continue
        # Max SNR
        time_veto_max_snr[slide_id][j] = \
                        max(trig_snr[slide_id][trial_cut])
        # Max BestNR
        time_veto_max_bestnr[slide_id][j] = \
                        max(trig_bestnr[slide_id][trial_cut])

logging.info("SNR and bestNR maxima calculated.")

# Output details of loudest offsource triggers, sorted by BestNR
sorted_trigs = ppu.sort_trigs(trial_dict, trigs, slide_dict, segment_dict)
# Pair each BestNR value with its trigger, over all the time slides
offsource_trigs = [
    bestnr_and_trig
    for slide_id in slide_dict
    for bestnr_and_trig in zip(trig_bestnr[slide_id], sorted_trigs[slide_id])
]
# Ascending sort on BestNR followed by a reversal: loudest trigger first
offsource_trigs.sort(key=lambda element: element[0])
offsource_trigs.reverse()

# Median and max values of SNR and BestNR
# (only the median is needed for the SNR)
_, median_snr, _ = ppu.max_median_stat(
    slide_dict, time_veto_max_snr, trig_snr, total_trials)
max_bestnr, median_bestnr, full_time_veto_max_bestnr = ppu.max_median_stat(
    slide_dict, time_veto_max_bestnr, trig_bestnr, total_trials)

# Table of the loudest offsource triggers, written to h5 and/or html
if lofft_outfile or lofft_h5_outfile:
    # td: table data (one row per trigger until it is transposed below)
    td = []

    # NOTE: Rather than changing the xml table structures, KAGRA piggy backs
    # on the TAMA column (t). The h2 column is left unused.
    # TAMA and Hanford2 are therefore no longer supported.
    ifo_att = {'G1':'g', 'H1':'h1', 'K1': 't', 'L1':'l', 'V1':'v'}

    # Gather properties of the loudest offsource triggers
    for i in range(min(len(offsource_trigs), opts.num_loudest_off_trigs)):
        bestnr, trig = offsource_trigs[i]
        trig_slide_id = int(trig.time_slide_id)

        # Get trial of trigger, triggers with 'No trial' should have been removed!
        for j, trial in enumerate(trial_dict[trig_slide_id]):
            if trig.get_end() in trial:
                chunk_num = j
                break
        else:
            chunk_num = 'No trial'

        # Get FAP of trigger: fraction of trials, over all the slides,
        # whose loudest BestNR is strictly louder than this trigger
        num_trials_louder = 0
        for slide_id in slide_dict:
            for val in time_veto_max_bestnr[slide_id]:
                if val > bestnr:
                    num_trials_louder += 1
        fap = num_trials_louder/total_trials
        # A vanishing FAP is reported as an upper limit
        pval = '< %.3g' % (1./total_trials) if fap == 0 else '%.3g' % fap

        d = [chunk_num, trig_slide_id, pval,\
             np.asarray(trig.get_end()).astype(float),\
             trig.mass1, trig.mass2, trig.mchirp,\
             (np.degrees(trig.ra)), (np.degrees(trig.dec)),\
             trig.snr, trig.chisq, trig.bank_chisq,\
             trig.cont_chisq, trig.get_null_snr()]
        # Single detector SNRs, then the time shift applied to each detector
        d.extend([getattr(trig, 'snr_%s' % ifo_att[ifo])\
                         for ifo in ifos])
        d.extend([slide_dict[trig_slide_id][ifo] for ifo in ifos])
        d.append(bestnr)
        td.append(d)

    # Number of triggers in the table: must be taken before the transpose,
    # after which len(td) is the number of columns
    num_table_trigs = len(td)

    # th: table header
    th = ['Trial', 'Slide Num', 'p-value', 'GPS time',\
          'Rec. m1', 'Rec. m2', 'Rec. Mc',\
          'Rec. RA', 'Rec. Dec', 'SNR', 'Chi^2', 'Bank veto', 'Auto veto',\
          'Null SNR']
    th.extend(['%s SNR' % ifo for ifo in ifos])
    th.extend(['%s time shift (s)' % ifo for ifo in ifos])
    th.append('BestNR')

    # To ensure desired formatting in the h5 file and html table:
    # 1) "transpose" the data preserving its dtype
    td = list(zip(*td))

    # Write to h5 file
    # The guard must test the h5 file name: the html file name may be set
    # while the h5 one is None, and vice versa
    if lofft_h5_outfile:
        logging.info("Writing %d lousdest offsource triggers to h5 file.",
                     num_table_trigs)
        # The context manager closes the file even if a dataset fails
        with h5py.File(lofft_h5_outfile, 'w') as lofft_h5_fp:
            for i, key in enumerate(th):
                lofft_h5_fp.create_dataset(key, data=td[i])

    # Write to html file
    if lofft_outfile:
        logging.info("Writing %d loudest triggers to html file.",
                     num_table_trigs)

        # To ensure desired formatting in the html table:
        # 2) convert the columns to numpy arrays
        # This is necessary as the p-values need to be treated as strings,
        # because they may contain a '<'
        td = [np.asarray(d) for d in td]

        # Format of table data
        format_strings = ['##.##', '##.##', None, '##.#####',
                          '##.##', '##.##', '##.##',
                          '##.##', '##.##', '##.##', '##.##', '##.##', '##.##',
                          '##.##']
        format_strings.extend(['##.##' for ifo in ifos])
        format_strings.extend(['##.##' for ifo in ifos])
        format_strings.extend(['##.##'])
        html_table = pycbc.results.html_table(td, th,
                                              format_strings=format_strings,
                                              page_size=30)
        kwds = {'title' : "Parameters of loudest offsource triggers",
                'caption' : "Parameters of the "+
                            str(min(len(offsource_trigs), opts.num_loudest_off_trigs))+
                            " loudest offsource triggers.  "+
                            "The median reweighted SNR value is "+
                            str(median_bestnr)+
                            ".  The median SNR value is "+
                            str(median_snr),
                'cmd' :' '.join(sys.argv), }
        pycbc.results.save_fig_with_metadata(str(html_table), lofft_outfile, **kwds)

    # Store BestNR and FAP values: for collective FAP value studies at the
    # end of an observing run collectively
    # TODO: Needs a final place in the results webpage
    # np.savetxt('%s/bestnr_vs_fap_numbers.txt' %(outdir),
    #            full_time_veto_max_bestnr, delimiter='/t')

# =======================
# Load on source triggers
# =======================
if onsource_file:

    # Get trigs
    on_trigs = ppu.load_xml_table(onsource_file, "multi_inspiral")
    logging.info("%d onsource triggers loaded.", len(on_trigs))

    # Separate off chirp mass column
    on_mchirp = on_trigs.get_column('mchirp')

    # Record loudest trig by BestNR (zero stands for "no trigger")
    loud_on_bestnr = 0
    if on_trigs:
        on_trigs_bestnrs = ppu.get_bestnrs(
            on_trigs,
            q=chisq_index,
            n=chisq_nhigh,
            null_thresh=null_thresh,
            snr_threshold=snr_thresh,
            sngl_snr_threshold=sngl_snr_thresh,
            chisq_threshold=new_snr_thresh,
            null_grad_thresh=null_grad_thresh,
            null_grad_val=null_grad_val)
        # Rank the (BestNR, trigger) pairs from loudest to quietest ...
        ranked_on_trigs = sorted(zip(on_trigs_bestnrs, on_trigs),
                                 key=lambda on_trig: on_trig[0],
                                 reverse=True)
        # ... and keep the first [trigger, BestNR] row
        ranked_on_rows = np.asarray(
            [[on_trig, on_bestnr] for on_bestnr, on_trig in ranked_on_trigs])
        loud_on_bestnr_trigs, loud_on_bestnr = ranked_on_rows[0]

    # If the loudest event has bestnr = 0, there is no event at all!
    if loud_on_bestnr == 0:
        loud_on_bestnr_trigs = None

    logging.info("Onsource analysed.")

    # Table data
    td = []

    # Gather data
    # The FAP defaults to one when there is no on-source event
    loud_on_fap = 1
    if loud_on_bestnr_trigs:
        trig = loud_on_bestnr_trigs
        # Count the off-source trials louder than the on-source event and
        # collect the loudest BestNR of every trial along the way
        num_trials_louder = 0
        tot_off_snr = np.array([])
        for slide_id in slide_dict:
            slide_max_bestnrs = time_veto_max_bestnr[slide_id]
            num_trials_louder += sum(slide_max_bestnrs > loud_on_bestnr)
            tot_off_snr = np.concatenate([tot_off_snr, slide_max_bestnrs])
        fap = num_trials_louder/total_trials
        fap_test = sum(tot_off_snr > loud_on_bestnr)/total_trials
        # A vanishing FAP is reported as an upper limit
        if fap == 0:
            pval = '< %.3g' % (1./total_trials)
        else:
            pval = '%.3g' % fap
        loud_on_fap = fap
        d = [pval, np.asarray(trig.get_end()).astype(float),
             trig.mass1, trig.mass2, trig.mchirp,
             np.degrees(trig.ra), np.degrees(trig.dec),
             trig.snr, trig.chisq, trig.bank_chisq,
             trig.cont_chisq, trig.get_null_snr()]
        d = d + [trig.get_sngl_snr(ifo) for ifo in ifos] + [loud_on_bestnr]
        td.append(d)
    else:
        # Placeholder row with as many entries as the table header
        td.append(["There are no events"] + [0] * (11 + len(ifos) + 1))

    # Table header
    th = ['p-value', 'GPS time', 'Rec. m1', 'Rec. m2', 'Rec. Mc', 'Rec. RA',
          'Rec. Dec', 'SNR', 'Chi^2', 'Bank veto', 'Auto veto', 'Null SNR']
    th = th + ['{} SNR'.format(ifo) for ifo in ifos] + ['BestNR']

    # "Transpose" the single row into columns
    td = list(zip(*td))

    # Write to h5 file
    if lont_h5_outfile:
        logging.info("Writing loudest onsource trigger to h5 file.")
        with h5py.File(lont_h5_outfile, 'w') as lont_h5_fp:
            for key, column in zip(th, td):
                lont_h5_fp.create_dataset(key, data=column)

    # Write to html file
    if lont_outfile:
        logging.info("Writing loudest onsource trigger to html file.")

        # Format of table data: the p-value is left unformatted, the GPS
        # time gets five decimals, everything else two
        format_strings = [None, '##.#####'] + ['##.##'] * 10
        format_strings.extend(['##.##'] * len(ifos))
        format_strings.append('##.##')

        # Table data
        td = [np.asarray(d) for d in td]
        html_table = pycbc.results.html_table(td, th,
                                              format_strings=format_strings,
                                              page_size=1)
        kwds = {
            'title': "Loudest event",
            'caption': "Recovered parameters and statistic values of the loudest trigger.",
            'cmd': ' '.join(sys.argv),
        }
        pycbc.results.save_fig_with_metadata(str(html_table), lont_outfile, **kwds)
else:
    # Without on-source triggers the median of the loudest off-source
    # BestNR values of all the trials is used as reference
    tot_off_snr = np.array([])
    for slide_id in slide_dict:
        tot_off_snr = np.concatenate(
            [tot_off_snr, time_veto_max_bestnr[slide_id]])
    med_snr = np.median(tot_off_snr)
    fap = sum(tot_off_snr > med_snr)/total_trials

# =======================
# Post-process injections
# =======================
if do_injections:

    # The found injections file is read twice: once for the triggers and
    # once for the injected signals
    found_trig_table = ppu.load_injections(found_file, vetoes)
    found_inj_table = ppu.load_injections(found_file, vetoes, sim_table=True)

    logging.info("Missed/found injections/triggers loaded.")

    # Extract columns of found injections and triggers
    found_injs = lsctable_to_dict(found_inj_table, ifos, opts,
                                  trig_table=found_trig_table,
                                  background_bestnrs=full_time_veto_max_bestnr)

    # Construct conditions for injection:
    # 1) found louder than background,
    zero_fap = found_injs['bestnr'] > max_bestnr

    # 2) found (bestnr > 0) but not louder than background (non-zero FAP)
    nonzero_fap = (found_injs['bestnr'] != 0) & ~zero_fap

    # 3) missed after being recovered (i.e., vetoed)
    # -- > question: is there ever another way this happens other than veto?
    # vetoed_trigs = (~zero_fap) & (~nonzero_fap)
    vetoed_trigs = found_injs['bestnr'] == 0

    # Used to separate triggers into:
    # 1) zero_fap 'g_found' -- > now accessed by indices: zero_fap
    # 2) nonzero_fap 'g_ifar' --> now accessed by indices: nonzero_fap
    # 3) missed because of vetoes 'g_missed2'--> now accessed by indices: vetoed_trigs

    logging.info("%d found injections analysed.", len(found_injs['mchirp']))

    # Missed injections (ones not recovered at all): there is no trigger
    # table for these, so only the injected parameters are extracted
    missed_inj_table = ppu.load_injections(missed_file, vetoes, sim_table=True)
    missed_injs = lsctable_to_dict(missed_inj_table, ifos, opts)

    # Avoids a problem with formatting in the non-static html output file
    missed_na = [-0] * len(missed_injs['mchirp'])

    logging.info("%d missed injections analysed.", len(missed_injs['mchirp']))

    # Efficiency marginalization starts here
    # Create new set of injections for efficiency calculations: as many as
    # the found and missed ones together, drawn uniformly in distance
    # starting from the upper injection distance
    total_injs = len(found_injs['mchirp']) + len(missed_injs['mchirp'])
    dist_range = upper_dist-lower_dist
    long_inj = {
        'dist': stats.uniform.rvs(size=total_injs) * dist_range + upper_dist,
    }

    logging.info("%d long distance injections created.", total_injs)

    # Set distance bins and data arrays
    # Each bin is a (lower edge, upper edge) pair
    bin_width = dist_range/num_bins
    bin_lower_edges = np.arange(lower_dist, upper_dist + dist_range, bin_width)
    dist_bins = list(zip(bin_lower_edges, bin_lower_edges + bin_width))

    # One slot per distance bin plus a last one for the running total
    num_dist_bins_plus_one = len(dist_bins) + 1
    num_injections = {}
    found_max_bestnr = {}
    found_on_bestnr = {}
    for key in ['mc', 'no_mc']:
        for counts in (num_injections, found_max_bestnr, found_on_bestnr):
            counts[key] = np.zeros(num_dist_bins_plus_one)

    # Calculate the amplitude error
    # Begin by calculating the components from each detector, which are
    # then added in quadrature
    cal_error = 0
    for ifo in ifos:
        ifo_sigma_mean = found_injs['inj_sigma_mean_%s' % ifo.lower()]
        cal_error += cal_errs[ifo]**2 * ifo_sigma_mean**2
    cal_error = cal_error**0.5

    max_dc_cal_error = max(cal_dc_errs.values())

    # Calibration phase uncertainties are neglected
    logging.info("Calibration amplitude uncertainty calculated.")

    # Now create the numbers for the efficiency plots.  These include calibration
    # and waveform errors incorporated by running over each injection num_mc_injs
    # times, where each time we draw a random value of distance.

    # Distribute injections: found, missed and long distance ones in turn
    found_injs['dist_mc'] = ppu.mc_cal_wf_errs(
        num_mc_injs, found_injs['distance'],
        cal_error, wav_err, max_dc_cal_error)
    missed_injs['dist_mc'] = ppu.mc_cal_wf_errs(
        num_mc_injs, missed_injs['distance'],
        cal_error, wav_err, max_dc_cal_error)
    long_inj['dist_mc'] = ppu.mc_cal_wf_errs(
        num_mc_injs, long_inj['dist'],
        cal_error, wav_err, max_dc_cal_error)
    logging.info("MC injection set distributed with %d iterations.",
                 num_mc_injs)

    # Check injections against on source
    # One flag per found injection: its FAP must not exceed the FAP of the
    # loudest on-source event (or 0.5 when there is no on-source file).
    # The comparison is deliberately not reduced to a single boolean,
    # otherwise one quiet injection would flag all of them as not found.
    if onsource_file:
        more_sig_than_onsource = (found_injs['fap'] <= loud_on_fap)
    else:
        more_sig_than_onsource = (found_injs['fap'] <= 0.5)

    distance_count = np.zeros(len(dist_bins))

    # Injections louder than the loudest background event
    max_bestnr_cut = (found_injs['bestnr'] > max_bestnr)

    # Check louder than on source
    if onsource_file:
        on_bestnr_cut = found_injs['bestnr'] > loud_on_bestnr
    else:
        on_bestnr_cut = found_injs['bestnr'] > med_snr

    # Check whether injection is found for the purposes of exclusion
    # distance calculation.
    # Found: if louder than all on source
    # Missed: if not louder than loudest on source
    found_excl = on_bestnr_cut & (more_sig_than_onsource) & \
                (found_injs['bestnr'] != 0)
    # If not missed, double check bestnr against nearby triggers
    near_test = np.zeros((found_excl).sum()).astype(bool)
    for j, (t, bestnr) in enumerate(zip(found_injs['time'][found_excl],\
                                        found_injs['bestnr'][found_excl])):
        # 0 is the zero-lag timeslide
        near_bestnr = trig_bestnr[0]\
                      [np.abs(trig_time[0]-t) < cluster_window]
        # The injection passes only if no trigger within the cluster
        # window, rescaled by the glitch check factor, is louder than it
        near_test[j] = not (near_bestnr * glitch_check_fac > bestnr).any()
    # Apply the local test: overwrite the entries that were flagged as
    # found, in order, with the outcome of the check above
    found_excl[np.flatnonzero(found_excl)] = near_test

    # Loop over each random instance of the injection set
    for k in range(num_mc_injs+1):
        # Distances of this instance for the three injection populations
        found_dists = found_injs['dist_mc'][k, :]
        missed_dists = missed_injs['dist_mc'][k, :]
        long_dists = long_inj['dist_mc'][k, :]

        # Loop over the distance bins
        for j, dist_bin in enumerate(dist_bins):
            # Construct distance cut: lower edge included, upper excluded
            found_dist_cut = (dist_bin[0] <= found_dists) &\
                             (found_dists < dist_bin[1])
            missed_dist_cut = (dist_bin[0] <= missed_dists) &\
                              (missed_dists < dist_bin[1])
            long_dist_cut = (dist_bin[0] <= long_dists) &\
                            (long_dists < dist_bin[1])

            # Count all injections in this distance bin
            num_found_pass = found_dist_cut.sum()
            num_missed_pass = missed_dist_cut.sum()
            num_long_pass = long_dist_cut.sum() or 0
            num_pass = num_found_pass + num_missed_pass + num_long_pass
            # Count only zero FAR injections
            num_zero_far = (found_dist_cut & max_bestnr_cut).sum()
            # Count number found for exclusion
            num_excl = (found_dist_cut & found_excl).sum()

            # Record number of injections, number found for exclusion
            # and number of zero FAR: the first instance goes under
            # 'no_mc', all the others are accumulated under 'mc'
            key = 'no_mc' if k == 0 else 'mc'
            # Both the slot of the bin and the last slot (running total
            # over the bins) are updated
            for slot in (j, -1):
                num_injections[key][slot] += num_pass
                found_max_bestnr[key][slot] += num_zero_far
                found_on_bestnr[key][slot] += num_excl

    logging.info("Found/missed injection efficiency calculations completed.")

    # Write quiet triggers to file: set up the column titles (th), the
    # display format of each column (format_strings), the injection keys
    # feeding each column (keys), and finally the table data itself (td)
    sites = [ifo[0] for ifo in ifos]
    # Column titles
    th = ['Dist']
    th.extend('Eff. Dist. %s' % site for site in sites)
    th.extend(['GPS time',
               'Inj. m1', 'Inj. m2', 'Inj. Mc',
               'Rec. m1', 'Rec. m2', 'Rec. Mc',
               'Inj. inc', 'Inj. RA', 'Inj. Dec', 'Rec. RA', 'Rec. Dec',
               'SNR', 'Chi^2', 'Bank veto', 'Auto veto', 'Null SNR'])
    th.extend('SNR %s' % ifo for ifo in ifos)
    th.extend(['BestNR',
               'Inj S1x', 'Inj S1y', 'Inj S1z',
               'Inj S2x', 'Inj S2y', 'Inj S2z'])
    # Format of table data: every column uses '##.##' except the GPS time,
    # which is shown with five decimal digits
    format_strings = ['##.##'] * (1 + len(ifos))
    format_strings.append('##.#####')
    # 16 columns from 'Inj. m1' to 'Null SNR', one single-detector SNR per
    # ifo, then BestNR and the six injected spin components
    format_strings.extend(['##.##'] * (16 + len(ifos) + 7))
    # Injection keys, in the same order as the column titles
    sngl_snr_keys = ['snr_%s' % ifo.lower() for ifo in ifos]
    keys = ['distance'] + ['eff_dist_%s' % ifo.lower() for ifo in ifos]
    keys.extend(['time', 'mass1', 'mass2', 'mchirp',
                 'rec_mass1', 'rec_mass2', 'rec_mchirp',
                 'inclination', 'ra', 'dec', 'rec_ra', 'rec_dec',
                 'snr', 'chisq', 'bank_chisq', 'cont_chisq', 'null_snr'])
    keys.extend(sngl_snr_keys)
    keys.extend(['bestnr', 'spin1x', 'spin1y', 'spin1z',
                 'spin2x', 'spin2y', 'spin2z'])
    # The following parameters are available only for recovered injections
    na_keys = ['rec_mass1', 'rec_mass2', 'rec_mchirp', 'rec_ra', 'rec_dec',
               'snr', 'chisq', 'bank_chisq', 'cont_chisq', 'null_snr',
               'bestnr'] + sngl_snr_keys
    # One column per key: found injections selected by nonzero_fap, then
    # found injections selected by vetoed_trigs, then the missed injections
    td = []
    for key in keys:
        found_column = found_injs[key]
        # Missed injections carry no recovered values: use missed_na instead
        missed_column = missed_na if key in na_keys else missed_injs[key]
        td.append(np.concatenate((found_column[nonzero_fap],
                                  found_column[vetoed_trigs],
                                  missed_column)))
    # Transpose to rows, order the rows by the first column (the distance),
    # and transpose back to columns
    td = list(zip(*sorted(zip(*td), key=lambda row: row[0])))
    # Write to h5 file: one dataset per table column, named after its title
    if qf_h5_outfile:
        # td is stored column by column at this point, hence len(td) is the
        # number of columns: the number of injections is the column length
        logging.info("Writing %d quiet-found injections to h5 file.",
                     len(td[0]) if td else 0)
        # NOTE(review): with no injections at all td is an empty list and
        # td[i] below raises IndexError - confirm whether this can happen
        with h5py.File(qf_h5_outfile, 'w') as qf_h5_fp:
            for i, key in enumerate(th):
                qf_h5_fp.create_dataset(key, data=td[i])

    # Write quiet-found injections to html file
    if qf_outfile:
        # td is stored column by column at this point, hence len(td) is the
        # number of columns: the number of injections is the column length
        logging.info("Writing %d quiet-found injections to html file.",
                     len(td[0]) if td else 0)
        # The html table is built from one array per column
        td = [np.asarray(d) for d in td]
        html_table = pycbc.results.html_table(td, th,
                                              format_strings=format_strings,
                                              page_size=20)
        kwds = {'title' : "Quiet found injections",
                'caption' : "Recovered parameters and statistic values of \
                injections that are recovered, but not louder than background.",
                'cmd' :' '.join(sys.argv), }
        pycbc.results.save_fig_with_metadata(str(html_table), qf_outfile, **kwds)

    # Write vetoed (missed-found) injections to html file
    if mf_outfile:
        # One column per key, restricted to the found injections selected by
        # vetoed_trigs; transpose to rows and sort them by the first column
        # (the distance)
        t_missed = [found_injs[key][vetoed_trigs] for key in keys]
        t_missed = sorted(zip(*t_missed), key=lambda elem: elem[0])
        # t_missed is a list of rows here, so its length is the number of
        # injections
        logging.info("Writing %d missed-found injections to html file.", len(t_missed))

        # Back to one array per column, as passed to the html table
        t_missed = [np.asarray(d) for d in zip(*t_missed)]
        html_table = pycbc.results.html_table(t_missed, th,
                                              format_strings=format_strings,
                                              page_size=20)
        kwds = {'title' : "Missed found injections",
                'caption' : "Recovered parameters and statistic values of \
                injections that are recovered, but downweighted to BestNR = 0 \
                (i.e., vetoed).",
                'cmd' :' '.join(sys.argv), }
        pycbc.results.save_fig_with_metadata(str(html_table), mf_outfile, **kwds)

# Post-processing of injections ends here