Revision 29df1b564abf7a6ed5230a5e97a85da70c403c2a authored by Ian Harry on 15 June 2017, 14:00:24 UTC, committed by Duncan Brown on 15 June 2017, 14:00:24 UTC
1 parent 23fd945
Raw File
pycbc_fit_sngl_trigs
#!/usr/bin/python

# Copyright 2015 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import numpy as np

from pycbc import io, events, bin_utils
from pycbc.events import trigger_fits as trstats
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####

def _plain_snr(snr, rchisq):
    """Return the SNR unchanged; the reduced chisq argument is ignored."""
    return snr


def _snr_over_root_rchisq(snr, rchisq):
    """Return the SNR divided by the square root of the reduced chisq."""
    return snr / (rchisq ** 0.5)


# Lookup table from the name of a single-trigger statistic (the values
# accepted by --sngl-stat) to the callable that evaluates it from the SNR
# and reduced chisq of the triggers.
stat_dict = {
    "new_snr": events.newsnr,
    "effective_snr": events.effsnr,
    "snr": _plain_snr,
    "snronchi": _snr_over_root_rchisq,
}

def get_stat(statchoice, snr, rchisq, fac):
    """Evaluate the chosen single-trigger statistic.

    Parameters
    ----------
    statchoice : str
        Key into ``stat_dict`` selecting the statistic.
    snr, rchisq
        Trigger SNR and reduced chisq values, handed to the statistic as its
        first two positional arguments.
    fac : float or None
        Optional extra factor.  When not None it is passed as a third
        positional argument, which is only permitted for "new_snr" and
        "effective_snr".

    Returns
    -------
    Whatever the selected statistic callable returns.

    Raises
    ------
    RuntimeError
        If ``fac`` is given together with a statistic that does not take it.
    """
    # No extra factor: every statistic is called with two arguments
    if fac is None:
        return stat_dict[statchoice](snr, rchisq)

    # The factor is only meaningful for these two statistics
    if statchoice not in ("new_snr", "effective_snr"):
        raise RuntimeError("Can't use --stat-factor with this statistic!")
    return stat_dict[statchoice](snr, rchisq, fac)

def get_bins(opt, pmin, pmax):
    """Build the parameter bins requested through the command-line options.

    Parameters
    ----------
    opt
        Parsed options.  ``opt.bin_spacing`` selects the kind of binning;
        ``opt.num_bins`` is used for "linear" and "log" spacing and
        ``opt.bin_edges`` for "irregular" spacing.
    pmin, pmax
        Lower and upper limits of the binned range.  Only used for "linear"
        and "log" spacing.

    Returns
    -------
    A ``bin_utils`` bins object of the corresponding type.

    Raises
    ------
    RuntimeError
        If ``opt.bin_spacing`` is not one of "linear", "log" or "irregular".
        Without this the function would return None and the caller would
        fail later with an unrelated error.
    """
    if opt.bin_spacing == "linear":
        return bin_utils.LinearBins(pmin, pmax, opt.num_bins)
    if opt.bin_spacing == "log":
        return bin_utils.LogarithmicBins(pmin, pmax, opt.num_bins)
    if opt.bin_spacing == "irregular":
        return bin_utils.IrregularBins(opt.bin_edges)
    raise RuntimeError("Unrecognized or missing --bin-spacing: %r (expected "
                       "'linear', 'log' or 'irregular')" % (opt.bin_spacing,))

#### MAIN ####

# Command-line interface.  Options whose help text says "Required" are
# enforced by argparse so that a missing option gives a clear usage error
# instead of a TypeError on None further down the script.
parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger "
                "distributions to various functions")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--inputs", nargs="+", required=True,
                    help="Input file or space-separated list of input files "
                    "containing single triggers.  Currently .xml(.gz) "
                    "and .hdf supported.  Required")
parser.add_argument("--bank-file", default=None,
                    help="hdf file containing template parameters; required "
                    "if fitting over trigger masses and/or spins with hdf "
                    "format triggers")
parser.add_argument("--veto-file", default=None,
                    help="File in .xml format with veto segments to apply to "
                    "triggers before fitting")
parser.add_argument("--veto-segment-name", default=None,
                    help="Name of veto segments to apply. Optional: if not "
                    "given, all segments for a given ifo will be used")
parser.add_argument("--output", required=True,
                    help="Location for output file containing fit coefficients"
                    ".  Required")
parser.add_argument("--plot-dir", default=None,
                    help="Plot the fits made, the variation of fitting "
                    "coefficients and the Kolmogorov-Smirnov test values "
                    "and save to the specified directory")
parser.add_argument("--user-tag", default="",
                    help="Put a possibly informative string in the names of "
                    "plot files")
parser.add_argument("--ifos", nargs="+", required=True,
                    help="Ifo or space-separated list of ifos to select "
                    "triggers to be fit.  Required")
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=["snr", "snronchi", "effective_snr", "new_snr"],
                    help="Function of SNR and chisq to perform fits with")
# FIXME: how to parse various possible mixtures of chisq
#parser.add_argument("--chisq-choice", default="power",
#                    choices=["power","bank","auto","maxpowerbank"],
#                    help="Which chisq value or values to form the "
#                    "single-trigger statistic with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics.  Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
# SNR cut applied when neither --snr-threshold nor --trig-filter is given
default_thresh = 5.5
# An SNR threshold and a general filter function cannot be combined
cuts = parser.add_mutually_exclusive_group()
cuts.add_argument("--snr-threshold", type=float,
                  help="Only fit triggers with SNR above this threshold. "
                  "Default %.1f" % default_thresh)
cuts.add_argument("--trig-filter", default=None,
                  help="Filter function applied to trigger properties: only "
                  "implemented for hdf format.  May include names from math, "
                  "numpy as 'np', pycbc.events and pycbc.pnutils.  Ex. "
                  "'events.newsnr(self.snr, self.bank_chisq/self.bank_chisq_"
                  "dof) > 6.'")
parser.add_argument("--stat-threshold", nargs="+", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                    "threshold : can be a space-separated list, then a fit "
                    "will be done for each threshold.  Required.  Typical "
                    "values 6.5 6.75 7")
parser.add_argument("--fit-param", required=True,
                    help="Parameter over which to estimate variation of the "
                    "fits. Required. Must be a SnglInspiralTable column for "
                    ".xml input or the name of a dataset for .hdf")
# FIXME: allow for math functions of columns. Ex. 1./mtotal
parser.add_argument("--bin-spacing", choices=["linear", "log", "irregular"],
                    help="How to space parameter bin edges")
# Exactly one way of specifying the bins must be chosen
binopt = parser.add_mutually_exclusive_group(required=True)
binopt.add_argument("--num-bins", type=int,
                    help="Number of regularly spaced bins to use over the "
                    "parameter")
binopt.add_argument("--irregular-bins",
                    help="Comma-separated list of parameter bin edges. "
                    "Required if --bin-spacing = irregular")

opt = parser.parse_args()

# Debug/info messages are only shown when --verbose is given
if opt.verbose:
    log_level = logging.DEBUG
else:
    log_level = logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Variants of the option values used in plot titles, axis labels and plot
# file names
statname = opt.sngl_stat.replace("_", " ")
paramname = opt.fit_param.replace("_", " ")
paramtag = opt.fit_param.replace("_", "")

if opt.plot_dir is not None:
    outdir = opt.plot_dir if opt.plot_dir.endswith('/') else opt.plot_dir+'/'

## Check option logic
# These options are indexed / iterated below, so fail clearly if any of
# them was not supplied
for optname in ("inputs", "ifos", "stat_threshold"):
    if not getattr(opt, optname):
        raise RuntimeError("Must specify --%s!" % optname.replace("_", "-"))

if opt.bin_spacing == "irregular":
    if opt.irregular_bins is None:
        raise RuntimeError("Must specify a list of irregular bin edges!")
    opt.bin_edges = [float(b) for b in opt.irregular_bins.split(',')]
elif opt.bin_spacing in ("linear", "log") and opt.num_bins is None:
    # regularly spaced bins cannot be built without a bin count
    raise RuntimeError("Must specify --num-bins for linear or log spacing!")

# The trigger format is decided from the name of the first input file only
name_parts = opt.inputs[0].split('.')
if name_parts[-1] == "xml" or name_parts[-2:] == ["xml", "gz"]:
    trigformat = "xml"
elif "hdf" in opt.inputs[0]:
    trigformat = "hdf"
    # columns for reading triggers without template file info
    cols = ["snr", "chisq", "template_duration", "template_id", "end_time"]
else:
    logging.error('inputs: %s', opt.inputs)
    raise RuntimeError("I didn't recognize the file format for the inputs!")

if opt.trig_filter and trigformat != "hdf":
    raise NotImplementedError("Trigger filtering not available for xml!")
# hack default values of SNR threshold and filter function
# Compare against None rather than testing truthiness, so that an explicit
# "--snr-threshold 0" is honoured instead of being replaced by the default
if opt.snr_threshold is not None:
    filter_func = 'self.snr > %f' % opt.snr_threshold
elif opt.trig_filter:
    filter_func = opt.trig_filter
else:
    opt.snr_threshold = default_thresh
    filter_func = 'self.snr > %f' % opt.snr_threshold

# initialize result storage
parbins = {}
counts = {}
templates = {}
fits = {}
stdev = {}
ks_prob = {}

# line colours for the per-bin histograms, indexed by bin number
histcolors = ['r',(1.0,0.6,0),'y','g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]

# Open the output only after all option checks have passed, so that a bad
# command line does not truncate an existing file
outfile = open(opt.output, 'w')

# header
outfile.write("# ifo threshold lower upper templates triggers alpha sig_alpha ks_prob\n")

# Main loop: for each ifo read the triggers, apply vetoes, compute the
# single-trigger statistic, bin the triggers in the fit parameter and fit
# the statistic distribution above each requested threshold in each bin.
# One line per (ifo, threshold, bin) is written to the output file.
done_vetoing = False
for ifo in opt.ifos:
    # FIXME: Would like a uniform input interface from trigger files
    # to the 'get_column' syntax used by SnglInspiralUtils
    if trigformat == "xml":
        from pylal import SnglInspiralUtils as sniuls
        sngls = sniuls.ReadSnglInspiralFromFiles(opt.inputs,
            filterFunc=lambda s: s.snr>opt.snr_threshold,
            verbose=opt.verbose)
        sngls.ifocut(ifo, inplace=True)
    elif trigformat == "hdf":
        if opt.bank_file:
            # io function can currently only handle 1 file
            if len(opt.inputs) > 1:
                logging.warning("Only the first input file %s is read when "
                                "--bank-file is given", opt.inputs[0])
            sngls = io.hdf.SingleDetTriggers(opt.inputs[0],
                  opt.bank_file, opt.veto_file, opt.veto_segment_name,
                  filter_func, ifo)
            # the veto file was passed to the reader above, so the explicit
            # veto step below is skipped
            done_vetoing = True
        else:
            sngls = io.hdf.DataFromFiles(opt.inputs, group=ifo,
                  columnlist=cols,
                  filter_func=filter_func)
    if opt.plot_dir:
        plotbase = outdir + ifo + "-" + opt.user_tag

    snr = sngls.get_column("snr")
    if not len(snr):
        logging.warning("no trigs, skipping %s" % ifo)
        continue
    chisq = sngls.get_column("chisq")
    chisq_dof = sngls.get_column("chisq_dof")
    # template_id or mass1,mass2 are used to count templates
    if trigformat == "hdf":
        tid = sngls.get_column("template_id")
        mass1 = None
        mass2 = None
    else:
        tid = None
        mass1 = sngls.get_column("mass1")
        mass2 = sngls.get_column("mass2")
    parvals = sngls.get_column(opt.fit_param)
    rchisq = chisq / (2 * chisq_dof - 2)
    #bank_rchisq = sngls.get_column('bank_chisq') / sngls.get_column('bank_chisq_dof')
    #rchisq = np.maximum(rchisq, bank_rchisq)

    if opt.veto_file and not done_vetoing:
        logging.info("Applying vetoes from %s for ifo %s" %
                     (opt.veto_file, ifo))
        keep_idx, segs = events.veto.indices_outside_segments(
                         sngls.get_column("end_time").astype(int),
                         [opt.veto_file], ifo=ifo,
                         segment_name=opt.veto_segment_name)
        snr = snr[keep_idx]
        rchisq = rchisq[keep_idx]
        if mass1 is not None: mass1 = mass1[keep_idx]
        if mass2 is not None: mass2 = mass2[keep_idx]
        if tid is not None: tid = tid[keep_idx]
        parvals = parvals[keep_idx]
        logging.info("%i trigs remain" % len(snr))
        if not len(snr):
            # nothing left to fit; min()/max() below would fail on empty data
            logging.warning("no trigs left after vetoes, skipping %s" % ifo)
            continue

    logging.info("Calculating ranking statistic values")
    statv = get_stat(opt.sngl_stat, snr, rchisq, opt.stat_factor)

    logging.info("Parameter range of triggers: %f - %f" %
                                            (np.min(parvals), np.max(parvals)))
    if opt.bin_spacing == "irregular":
        edge_lo = min(opt.bin_edges)
        edge_hi = max(opt.bin_edges)
        logging.info("Removing triggers outside bin range %f - %f" %
                                                           (edge_lo, edge_hi))
        in_range = np.logical_and(parvals >= edge_lo, parvals <= edge_hi)
        statv = statv[in_range]
        parvals = parvals[in_range]
        # The template-counting arrays must be cut in the same way: they are
        # indexed below with masks built from the filtered parameter values
        if tid is not None: tid = tid[in_range]
        if mass1 is not None: mass1 = mass1[in_range]
        if mass2 is not None: mass2 = mass2[in_range]
        logging.info("%i remain" % len(statv))
        if not len(statv):
            logging.warning("no trigs inside bin range, skipping %s" % ifo)
            continue

    # Pad the binned range by 1% on each side so the extreme triggers fall
    # inside the bins.  The padding must widen the range whatever the sign
    # of the limits; for non-negative values this is exactly 0.99*min and
    # 1.01*max.
    par_lo = np.min(parvals)
    par_hi = np.max(parvals)
    bin_lo = 0.99 * par_lo if par_lo >= 0 else 1.01 * par_lo
    bin_hi = 1.01 * par_hi if par_hi >= 0 else 0.99 * par_hi
    pbins = get_bins(opt, bin_lo, bin_hi)
    # remember the bins used for this ifo
    parbins[ifo] = pbins
    # list of bin indices
    binind = [pbins[c] for c in pbins.centres()]
    logging.info("Assigning trigger param values to bins")
    # FIXME: This is slow!! either find a better way of using pylal.rate
    # or write faster binning routine
    pind = np.array([pbins[par] for par in parvals])

    logging.info("Getting max counts in bins")
    # determine trigger counts first to get plot limits to make them the same
    # for all thresholds; use only the smallest threshold requested
    minth = min(opt.stat_threshold)
    bincounts = [np.count_nonzero(statv[pind == i] >= minth) for i in binind]
    maxcount = max(bincounts)
    plotrange = np.linspace(0.95 * np.min(statv), 1.05 * np.max(statv), 100)

    templates[ifo] = {}
    counts[ifo] = {}
    fits[ifo] = {}
    stdev[ifo] = {}
    ks_prob[ifo] = {}

    for th in opt.stat_threshold:
        logging.info("Fitting above threshold %f" % th)
        counts[ifo][th] = {}
        fits[ifo][th] = {}
        stdev[ifo][th] = {}
        ks_prob[ifo][th] = {}

        for i, lower, upper in zip(binind, pbins.lower(), pbins.upper()):
            # boolean mask selecting the triggers that fall in this bin
            in_bin = pind == i
            # determine number of templates generating the triggers involved
            # for hdf5, use the template id; otherwise use masses
            if trigformat == "hdf":
                numtmpl = len(set(tid[in_bin]))
            else:
                numtmpl = len(set(zip(mass1[in_bin], mass2[in_bin])))
            templates[ifo][i] = numtmpl
            vals_inbin = statv[in_bin]
            counts[ifo][th][i] = np.count_nonzero(vals_inbin >= th)
            if len(vals_inbin) == 0:
                logging.warning("No trigs in bin %s %s" % (lower, upper))
                # Record the bin as having no fit so that every bin has an
                # entry when the summary plots loop over all bins
                fits[ifo][th][i] = np.nan
                stdev[ifo][th][i] = np.nan
                ks_prob[ifo][th][i] = np.nan
                continue
            # FIXME: Also allow to use dynamically determined threshold
            # calculated via Nth loudest trig
            #thresh = trstats.tail_threshold(vals_inbin, N=nmax)
            #counts[ifo][th][i] = nmax
            # do the fit
            alpha, sig_alpha = trstats.fit_above_thresh(
                                              opt.fit_function, vals_inbin, th)
            #alpha, sig_alpha = trstats.fit_above_thresh(
            #                             opt.fit_function, vals_inbin, thresh)
            fits[ifo][th][i] = alpha
            stdev[ifo][th][i] = sig_alpha
            _, ks_prob[ifo][th][i] = trstats.KS_test(
                                       opt.fit_function, vals_inbin, alpha, th)
            outfile.write("%s %.2f %.3g %.3g %d %d %.3f %.3f %.3g\n" %
              (ifo, th, lower, upper, numtmpl, counts[ifo][th][i], alpha,
               sig_alpha, ks_prob[ifo][th][i]))

            # add histogram to plot
            if opt.plot_dir:
                # cycle through the colour list so that more bins than
                # colours does not raise an IndexError
                bincolor = histcolors[i % len(histcolors)]
                histcounts, edges = np.histogram(vals_inbin, bins=50)
                # number of triggers at or above each histogram edge
                cum_counts = histcounts[::-1].cumsum()[::-1]
                binlabel = r"%.2g - %.2g" % (lower, upper)
                plt.semilogy(edges[:-1], cum_counts, linewidth=2,
                             color=bincolor, label=binlabel, alpha=0.6)

                # fitted cumulative distribution, then the same with alpha
                # shifted up and down by its uncertainty
                plt.semilogy(plotrange, counts[ifo][th][i] * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha, th),
                  "--", color=bincolor,
                  label=r"$\alpha = $%.2f $\pm$ %.2f" % (alpha, sig_alpha))
                plt.semilogy(plotrange, counts[ifo][th][i] * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha + \
                  sig_alpha, th), ":", alpha=0.6, color=bincolor)
                plt.semilogy(plotrange, counts[ifo][th][i] * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha - \
                  sig_alpha, th), ":", alpha=0.6, color=bincolor)

        if opt.plot_dir:
            leg = plt.legend(labelspacing=0.2)
            plt.setp(leg.get_texts(), fontsize=11)
            plt.ylim(0.7, 2*maxcount)
            plt.xlim(0.9*min(opt.stat_threshold), 1.1*max(plotrange))
            plt.grid()
            plt.title(ifo + " " + statname + " distribution split by " + \
                  paramname)
            plt.xlabel(statname, size="large")
            plt.ylabel("Cumulative number", size="large")
            dest = plotbase + "_" + opt.sngl_stat + "_cdf_by_" + \
              paramtag[0:3] + "_fit_thresh_" + str(th) + ".png"
            logging.info("Saving plot to %s" % dest)
            plt.savefig(dest)
            plt.close()

# All fit results have been written at this point: close the output file so
# its contents are on disk even if the plotting below fails
outfile.close()

# make plots of alpha and KS significance for ifos having triggers
if opt.plot_dir:
    for ifo in fits:
        # Build the file prefix for this ifo; reusing the value left over
        # from the fitting loop would save every ifo's plots under the name
        # of the last ifo processed
        plotbase = outdir + ifo + "-" + opt.user_tag
        # Use the bins recorded for this ifo when available, otherwise fall
        # back to the bins left over from the fitting loop
        ifo_bins = parbins.get(ifo, pbins)
        centres = ifo_bins.centres()

        # fitted alpha with its uncertainty versus the binned parameter;
        # bins without a fit are drawn as NaN rather than raising KeyError
        for th in opt.stat_threshold:
            alphas = [fits[ifo][th].get(i, np.nan) for i in binind]
            sigmas = [stdev[ifo][th].get(i, np.nan) for i in binind]
            plt.errorbar(centres, alphas, yerr=sigmas, fmt="+-",
              label=ifo + " fit above %.2f" % th)
        if opt.bin_spacing == "log": plt.semilogx()
        plt.grid()
        plt.legend(loc="best")
        plt.xlabel(paramname, size="large")
        plt.ylabel(r"fit parameter $\alpha$", size='large')
        plt.savefig(plotbase + '_alpha_vs_' + paramtag[0:3] + '.png')
        plt.close()

        # KS test p-value versus the binned parameter
        for th in opt.stat_threshold:
            pvalues = [ks_prob[ifo][th].get(i, np.nan) for i in binind]
            plt.plot(centres, pvalues,
              '+--', label=ifo+' KS prob, thresh %.2f' % th)
        if opt.bin_spacing == 'log': plt.loglog()
        else : plt.semilogy()
        plt.grid()
        leg = plt.legend(loc='best', labelspacing=0.2)
        plt.setp(leg.get_texts(), fontsize=11)
        plt.xlabel(paramname, size='large')
        plt.ylabel('KS test p-value')
        plt.savefig(plotbase + '_KS_prob_vs_' + paramtag[0:3] + '.png')
        plt.close()

logging.info('Done!')
back to top