Revision 1457b0e7bf88393f77faeb4d48d944830e804c6a authored by Duncan Brown on 18 September 2016, 00:40:38 UTC, committed by samantha-usman on 18 September 2016, 00:40:38 UTC
* comment out rom check as the checksums are incorrect

* fixed lalapps build

* disable the parts of lalapps that can cause build failures

* fixed rom checksums

* removed hdf5 lib install needed for centos 5

* move rom checks to after r7 is downloaded

* put lal extra data in src directory

* force a link of lalapps against zlib

* when building dev install glue and pylal from github
1 parent 5778e26
Raw File
pycbc_fit_sngl_trigs
#!/usr/bin/python

# Copyright 2015 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import numpy as np
from scipy.stats import kstest

from pylal import SnglInspiralUtils as sniuls, trigger_fits as trstats, rate

from pycbc import io, events
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####

def _snr_only(snr, rchisq):
    """Return the SNR unchanged; the reduced chisq argument is ignored."""
    return snr


def _snr_over_root_rchisq(snr, rchisq):
    """Return the SNR divided by the square root of the reduced chisq."""
    return snr / (rchisq ** 0.5)


# Single-trigger statistics selectable with --sngl-stat, keyed by the option
# value.  Every entry is called as f(snr, rchisq); the two pycbc.events
# functions are additionally called as f(snr, rchisq, fac) by get_stat when a
# --stat-factor is given.
stat_dict = dict(
    new_snr=events.newsnr,
    effective_snr=events.effsnr,
    snr=_snr_only,
    snronchi=_snr_over_root_rchisq,
)

def get_stat(statchoice, snr, rchisq, fac):
    """Evaluate the single-trigger statistic named by statchoice.

    statchoice : key of stat_dict selecting the statistic
    snr, rchisq : SNR and reduced chisq values handed to the statistic
    fac : optional extra factor; when not None it is passed as a third
          argument, which is only allowed for "new_snr" and "effective_snr"

    Returns whatever the selected statistic function returns.
    Raises RuntimeError if fac is given for a statistic that does not take it.
    """
    # Without a factor every statistic is called with two arguments.
    if fac is None:
        return stat_dict[statchoice](snr, rchisq)

    # Reject the factor before looking the statistic up.
    accepts_factor = ("new_snr", "effective_snr")
    if statchoice not in accepts_factor:
        raise RuntimeError("Can't use --stat-factor with this statistic!")

    statistic = stat_dict[statchoice]
    return statistic(snr, rchisq, fac)

def get_bins(opt, pmin, pmax):
    """Build the parameter bins requested on the command line.

    opt : parsed options; opt.bin_spacing selects the binning, opt.num_bins
          is read for "linear" and "log" spacing and opt.bin_edges for
          "irregular" spacing
    pmin, pmax : lower and upper limits handed to the linear and logarithmic
                 bins; not used for irregular bins

    Returns a rate.LinearBins, rate.LogarithmicBins or rate.IrregularBins
    instance.
    Raises RuntimeError if opt.bin_spacing is not one of the three supported
    values (including None when the option was not given), instead of
    silently returning None and failing later at the first use of the bins.
    """
    if opt.bin_spacing == "linear":
        return rate.LinearBins(pmin, pmax, opt.num_bins)
    if opt.bin_spacing == "log":
        return rate.LogarithmicBins(pmin, pmax, opt.num_bins)
    if opt.bin_spacing == "irregular":
        return rate.IrregularBins(opt.bin_edges)
    raise RuntimeError("Unrecognized bin spacing %r: must be 'linear', 'log' "
                       "or 'irregular'" % (opt.bin_spacing,))

#### MAIN ####

# Command-line interface.  Options whose help text says "Required" are
# enforced by argparse, so that leaving one out gives a usage error rather
# than a failure on a None value further down the script.
parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger "
                "distributions to various functions")

parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--inputs", nargs="+", required=True,
                    help="Input file or space-separated list of input files "
                    "containing single triggers.  Currently .xml(.gz), "
                    ".sqlite and .hdf supported.  Required")
parser.add_argument("--template-file", default=None,
                    help="hdf file containing template parameters; required "
                    "if fitting over trigger masses and/or spins with hdf "
                    "format triggers")
parser.add_argument("--veto-file", default=None,
                    help="File in .xml format with veto segments to apply to "
                    "triggers before fitting")
parser.add_argument("--veto-segment-name", default=None,
                    help="Name of veto segments to apply. Optional: if not "
                    "given, all segments for a given ifo will be used")
parser.add_argument("--output", required=True,
                    help="Location for output file containing fit coefficients"
                    ".  Required")
parser.add_argument("--plot-dir", default=None,
                    help="Plot the fits made, the variation of fitting "
                    "coefficients and the Kolmogorov-Smirnov test values "
                    "and save to the specified directory")
parser.add_argument("--user-tag", default="",
                    help="Put a possibly informative string in the names of "
                    "plot files")
parser.add_argument("--ifos", nargs="+", required=True,
                    help="Ifo or space-separated list of ifos to select "
                    "triggers to be fit.  Required")
# NOTE(review): --fit-function has neither a default nor required=True, so
# opt.fit_function can be None; confirm how trigger_fits handles that.
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=["snr", "snronchi", "effective_snr", "new_snr"],
                    help="Function of SNR and chisq to perform fits with")
# FIXME: how to parse various possible mixtures of chisq
#parser.add_argument("--chisq-choice", default="power",
#                    choices=["power","bank","auto","maxpowerbank"],
#                    help="Which chisq value or values to form the "
#                    "single-trigger statistic with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics.  Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
# SNR threshold applied when neither --snr-threshold nor --trig-filter is given
default_thresh = 5.5
# an SNR threshold and a general filter function cannot be combined
cuts = parser.add_mutually_exclusive_group()
cuts.add_argument("--snr-threshold", type=float,
                  help="Only fit triggers with SNR above this threshold. "
                  "Default %.1f" % default_thresh)
cuts.add_argument("--trig-filter", default=None,
                  help="Filter function applied to trigger properties: only "
                  "implemented for hdf format.  May include names from math, "
                  "numpy as 'np', pycbc.events and pycbc.pnutils.  Ex. "
                  "'events.newsnr(self.snr, self.bank_chisq/self.bank_chisq_"
                  "dof) > 6.)'")
parser.add_argument("--stat-threshold", nargs="+", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                    "threshold : can be a space-separated list, then a fit "
                    "will be done for each threshold.  Required.  Typical "
                    "values 6.5 6.75 7")
parser.add_argument("--fit-param", required=True,
                    help="Parameter over which to estimate variation of the "
                    "fits. Required. Must be a SnglInspiralTable column for "
                    ".xml or .sqlite input or the name of a dataset for .hdf")
# FIXME: allow for math functions of columns. Ex. 1./mtotal")
# get_bins cannot build any bins without a spacing, so the option is required
parser.add_argument("--bin-spacing", choices=["linear", "log", "irregular"],
                    required=True,
                    help="How to space parameter bin edges")
# exactly one way of specifying the bins must be given
binopt = parser.add_mutually_exclusive_group(required=True)
binopt.add_argument("--num-bins", type=int,
                    help="Number of regularly spaced bins to use over the "
                    " parameter")
binopt.add_argument("--irregular-bins",
                    help="Comma-separated list of parameter bin edges. "
                    "Required if --bin-spacing = irregular")

opt = parser.parse_args()

# Set up logging, derive the names used in plot labels and file names,
# validate the combination of options, work out the trigger file format from
# the first input file name and build the trigger filter.  The output file is
# opened last, so that a rejected command line does not truncate an existing
# output file.
if opt.verbose:
    log_level = logging.DEBUG
else:
    log_level = logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# human-readable names for plot labels and a compact tag for plot file names
statname = opt.sngl_stat.replace("_", " ")
paramname = opt.fit_param.replace("_", " ")
paramtag = opt.fit_param.replace("_", "")

if opt.plot_dir is not None:
    # make sure the directory ends in '/' so file names can be appended
    outdir = opt.plot_dir if opt.plot_dir.endswith('/') else opt.plot_dir+'/'

## Check option logic
if opt.bin_spacing == "irregular":
    if opt.irregular_bins is None:
        raise RuntimeError("Must specify a list of irregular bin edges!")
    opt.bin_edges = [float(b) for b in opt.irregular_bins.split(',')]
elif opt.bin_spacing in ("linear", "log") and opt.num_bins is None:
    # --irregular-bins was given instead of --num-bins
    raise RuntimeError("Must specify --num-bins for linear or log bin "
                       "spacing!")

# the trigger format is decided by the name of the first input file only
name_parts = opt.inputs[0].split('.')
if name_parts[-1] == "sqlite":
    trigformat = "sqlite"
    table = "sngl_inspiral"
    if len(opt.inputs) > 1:
        raise RuntimeError("Sorry, can't deal with more than one .sqlite "
                           "input file")
    cols = ['ifo', 'end_time', 'end_time_ns', 'snr', 'chisq', 'chisq_dof',
            'bank_chisq', 'bank_chisq_dof', 'cont_chisq', 'cont_chisq_dof',
            'mchirp', 'eta', 'mtotal', 'mass1', 'mass2', 'template_duration']
    if opt.fit_param not in cols:
        cols.append(opt.fit_param)
elif name_parts[-1] == "xml" or name_parts[-2:] == ["xml", "gz"]:
    trigformat = "xml"
elif "hdf" in opt.inputs[0]:
    trigformat = "hdf"
    # columns for reading triggers without template file info
    cols = ["snr", "chisq", "template_duration", "template_id", "end_time"]
else:
    # single-argument print call: same output under Python 2 and Python 3
    print('inputs: %s' % (opt.inputs,))
    raise RuntimeError("I didn't recognize the file format for the inputs!")

if opt.trig_filter and trigformat != "hdf":
    raise NotImplementedError("Trigger filtering not available for xml and sqlite!")
# hack default values of SNR threshold and filter function
# (compare against None so that an explicit threshold of 0 is honoured)
if opt.snr_threshold is None and not opt.trig_filter:
    opt.snr_threshold = default_thresh
if opt.snr_threshold is not None:
    filter_func = 'self.snr > %f' % opt.snr_threshold
else:
    filter_func = opt.trig_filter

outfile = open(opt.output, 'w')

# containers for the per-ifo results, filled in while looping over the ifos
parbins, counts, templates = {}, {}, {}
fits, stdev, ks_prob = {}, {}, {}

# line colours for the per-bin curves, looked up by parameter bin index
histcolors = [
    'r', (1.0, 0.6, 0), 'y', 'g', 'c', 'b', 'm', 'k',
    (0.8, 0.25, 0), (0.25, 0.8, 0),
]

# first line of the output file: names of the columns written for each fit
header = "# ifo threshold lower upper templates triggers alpha sig_alpha ks_prob\n"
outfile.write(header)

# Main analysis.  For each ifo: read its triggers, optionally apply vetoes,
# compute the chosen single-trigger statistic, bin the triggers in the fit
# parameter, then fit the statistic distribution above every requested
# threshold in every bin, writing one output line per (threshold, bin).
# Afterwards, plots of the fit coefficient and of the KS p-value against the
# fit parameter are made for every ifo that had triggers.
done_vetoing = False
for ifo in opt.ifos:
    # FIXME: Would like a uniform input interface from trigger files
    # to the 'get_column' syntax used by SnglInspiralUtils
    if trigformat == "sqlite":
        sngls = io.sqlite.TableData(table, cols)
        sngls.connect_db(opt.inputs[0])
        sngls.restrict_rows(conditions=["ifo='"+str(ifo)+"'",
            "snr>"+str(opt.snr_threshold)])
        sngls.retrieve_rows()
    elif trigformat == "xml":
        sngls = sniuls.ReadSnglInspiralFromFiles(opt.inputs,
            filterFunc=lambda s: s.snr>opt.snr_threshold,
            verbose=opt.verbose)
        sngls.ifocut(ifo, inplace=True)
    elif trigformat == "hdf":
        if opt.template_file:
            # io function can currently only handle 1 file
            sngls = io.hdf.SingleDetTriggers(opt.inputs[0],
                  opt.template_file, opt.veto_file, opt.veto_segment_name,
                  filter_func, ifo)
            # the veto file was handed to the reader, so the separate
            # vetoing step below is skipped
            done_vetoing = True
        else:
            sngls = io.hdf.DataFromFiles(opt.inputs, group=ifo,
                  columnlist=cols,
                  filter_func=filter_func)
    if opt.plot_dir:
        plotbase = outdir + ifo + "-" + opt.user_tag

    snr = sngls.get_column("snr")
    if not len(snr):
        logging.warning("no trigs, skipping %s" % ifo)
        continue
    chisq = sngls.get_column("chisq")
    chisq_dof = sngls.get_column("chisq_dof")
    # template_id or mass1,mass2 are used to count templates
    if trigformat == "hdf":
        tid = sngls.get_column("template_id")
        mass1 = None
        mass2 = None
    else:
        tid = None
        mass1 = sngls.get_column("mass1")
        mass2 = sngls.get_column("mass2")
    parvals = sngls.get_column(opt.fit_param)
    rchisq = chisq / (2 * chisq_dof - 2)
    #bank_rchisq = sngls.get_column('bank_chisq') / sngls.get_column('bank_chisq_dof')
    #rchisq = np.maximum(rchisq, bank_rchisq)

    if opt.veto_file and not done_vetoing:
        logging.info("Applying vetoes from %s for ifo %s" %
                     (opt.veto_file, ifo))
        keep_idx, segs = events.veto.indices_outside_segments(
                         sngls.get_column("end_time").astype(int),
                         [opt.veto_file], ifo=ifo,
                         segment_name=opt.veto_segment_name)
        snr = snr[keep_idx]
        rchisq = rchisq[keep_idx]
        if mass1 is not None: mass1 = mass1[keep_idx]
        if mass2 is not None: mass2 = mass2[keep_idx]
        if tid is not None: tid = tid[keep_idx]
        parvals = parvals[keep_idx]
        logging.info("%i trigs remain" % len(snr))
        if not len(snr):
            # nothing left to fit; min/max below would fail on empty input
            logging.warning("no trigs left after vetoes, skipping %s" % ifo)
            continue

    logging.info("Calculating ranking statistic values")
    statv = get_stat(opt.sngl_stat, snr, rchisq, opt.stat_factor)

    logging.info("Parameter range of triggers: %f - %f" %
                                                  (min(parvals), max(parvals)))
    if opt.bin_spacing == "irregular":
        logging.info("Removing triggers outside bin range %f - %f" %
                                      (min(opt.bin_edges), max(opt.bin_edges)))
        in_range = np.logical_and(parvals >= min(opt.bin_edges),
                                  parvals <= max(opt.bin_edges))
        statv = statv[in_range]
        parvals = parvals[in_range]
        # the template identifiers are masked with the bin indices of the
        # surviving triggers further down, so they must be cut the same way
        if mass1 is not None: mass1 = mass1[in_range]
        if mass2 is not None: mass2 = mass2[in_range]
        if tid is not None: tid = tid[in_range]
        logging.info("%i remain" % len(statv))
        if not len(statv):
            logging.warning("no trigs in bin range, skipping %s" % ifo)
            continue
    pbins = get_bins(opt, 0.99 * min(parvals), 1.01 * max(parvals))
    # remember the bins of this ifo for the summary plots after the loop
    parbins[ifo] = pbins
    # list of bin indices
    binind = [pbins[c] for c in pbins.centres()]
    logging.info("Assigning trigger param values to bins")
    # FIXME: This is slow!! either find a better way of using pylal.rate
    # or write faster binning routine
    pind = np.array([pbins[par] for par in parvals])

    logging.info("Getting max counts in bins")
    # determine trigger counts first to get plot limits to make them the same
    # for all thresholds; use only the smallest threshold requested
    minth = min(opt.stat_threshold)
    bincounts = []
    for i in binind:
        vals_inbin = statv[pind == i]
        bincounts.append(np.count_nonzero(vals_inbin >= minth))
    maxcount = max(bincounts)
    plotrange = np.linspace(0.95 * min(statv), 1.05 * max(statv), 100)

    templates[ifo] = {}
    counts[ifo] = {}
    fits[ifo] = {}
    stdev[ifo] = {}
    ks_prob[ifo] = {}

    for th in opt.stat_threshold:
        logging.info("Fitting above threshold %f" % th)
        counts[ifo][th] = {}
        fits[ifo][th] = {}
        stdev[ifo][th] = {}
        ks_prob[ifo][th] = {}

        for i, lower, upper in zip(binind, pbins.lower(), pbins.upper()):
            in_bin = pind == i
            # determine number of templates generating the triggers involved
            # for hdf5, use the template id; otherwise use masses
            if trigformat == "hdf":
                numtmpl = len(set(tid[in_bin]))
            else:
                numtmpl = len(set(zip(mass1[in_bin], mass2[in_bin])))
            templates[ifo][i] = numtmpl
            vals_inbin = statv[in_bin]
            counts[ifo][th][i] = np.count_nonzero(vals_inbin >= th)
            if len(vals_inbin) == 0:
                # single-argument print call: same output under Python 2 and 3
                print("No trigs in bin %s %s" % (lower, upper))
                continue
            # FIXME: Also allow to use dynamically determined threshold
            # calculated via Nth loudest trig
            #thresh = trstats.tail_threshold(vals_inbin, N=nmax)
            #counts[ifo][th][i] = nmax
            # do the fit
            alpha, sig_alpha = trstats.fit_above_thresh(
                                              opt.fit_function, vals_inbin, th)
            #alpha, sig_alpha = trstats.fit_above_thresh(
            #                             opt.fit_function, vals_inbin, thresh)
            fits[ifo][th][i] = alpha
            stdev[ifo][th][i] = sig_alpha
            _, ks_prob[ifo][th][i] = trstats.KS_test(
                                       opt.fit_function, vals_inbin, alpha, th)
            outfile.write("%s %.2f %.3g %.3g %d %d %.3f %.3f %.3g\n" %
              (ifo, th, lower, upper, numtmpl, counts[ifo][th][i], alpha,
               sig_alpha, ks_prob[ifo][th][i]))

            # add histogram to plot
            if opt.plot_dir:
                # cycle through the colours when there are more bins than
                # colours instead of running off the end of the list
                bincolor = histcolors[i % len(histcolors)]
                histcounts, edges = np.histogram(vals_inbin, bins=50)
                # number of triggers at or above each histogram edge
                cum_counts = histcounts[::-1].cumsum()[::-1]
                binlabel = r"%.2g - %.2g" % (lower, upper)
                plt.semilogy(edges[:-1], cum_counts, linewidth=2,
                             color=bincolor, label=binlabel, alpha=0.6)

                # fitted cumulative distribution and its +/- 1 sigma variants
                plt.semilogy(plotrange, counts[ifo][th][i] * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha, th),
                  "--", color=bincolor,
                  label=r"$\alpha = $%.2f $\pm$ %.2f" % (alpha, sig_alpha))
                plt.semilogy(plotrange, counts[ifo][th][i] * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha + \
                  sig_alpha, th), ":", alpha=0.6, color=bincolor)
                plt.semilogy(plotrange, counts[ifo][th][i] * \
                  trstats.cum_fit(opt.fit_function, plotrange, alpha - \
                  sig_alpha, th), ":", alpha=0.6, color=bincolor)

        if opt.plot_dir:
            leg = plt.legend(labelspacing=0.2)
            plt.setp(leg.get_texts(), fontsize=11)
            plt.ylim(0.7, 2*maxcount)
            plt.xlim(0.9*min(opt.stat_threshold), 1.1*max(plotrange))
            plt.grid()
            plt.title(ifo + " " + statname + " distribution split by " + \
                  paramname)
            plt.xlabel(statname, size="large")
            plt.ylabel("Cumulative number", size="large")
            dest = plotbase + "_" + opt.sngl_stat + "_cdf_by_" + \
              paramtag[0:3] + "_fit_thresh_" + str(th) + ".png"
            logging.info("Saving plot to %s" % dest)
            plt.savefig(dest)
            plt.close()

# make plots of alpha and KS significance for ifos having triggers
if opt.plot_dir:
    for ifo in fits.keys():
        # use the bins and the file name prefix of this ifo rather than
        # whatever the last pass of the loop above left behind
        ifo_bins = parbins[ifo]
        ifo_centres = ifo_bins.centres()
        ifo_binind = [ifo_bins[c] for c in ifo_centres]
        ifo_plotbase = outdir + ifo + "-" + opt.user_tag

        for th in opt.stat_threshold:
            # bins without triggers were skipped and have no fit stored
            fit_idx = [i for i in ifo_binind if i in fits[ifo][th]]
            fit_ctr = [c for i, c in zip(ifo_binind, ifo_centres)
                       if i in fits[ifo][th]]
            if not fit_idx:
                continue
            plt.errorbar(fit_ctr, [fits[ifo][th][i] for i in fit_idx],
              yerr=[stdev[ifo][th][i] for i in fit_idx], fmt="+-",
              label=ifo + " fit above %.2f" % th)
        if opt.bin_spacing == "log": plt.semilogx()
        plt.grid()
        plt.legend(loc="best")
        plt.xlabel(paramname, size="large")
        plt.ylabel(r"fit parameter $\alpha$", size='large')
        plt.savefig(ifo_plotbase + '_alpha_vs_' + paramtag[0:3] + '.png')
        plt.close()

        for th in opt.stat_threshold:
            fit_idx = [i for i in ifo_binind if i in ks_prob[ifo][th]]
            fit_ctr = [c for i, c in zip(ifo_binind, ifo_centres)
                       if i in ks_prob[ifo][th]]
            if not fit_idx:
                continue
            plt.plot(fit_ctr, [ks_prob[ifo][th][i] for i in fit_idx],
              '+--', label=ifo+' KS prob, thresh %.2f' % th)
        if opt.bin_spacing == 'log': plt.loglog()
        else : plt.semilogy()
        plt.grid()
        leg = plt.legend(loc='best', labelspacing=0.2)
        plt.setp(leg.get_texts(), fontsize=11)
        plt.xlabel(paramname, size='large')
        plt.ylabel('KS test p-value')
        plt.savefig(ifo_plotbase + '_KS_prob_vs_' + paramtag[0:3] + '.png')
        plt.close()

# flush the fit results to disk
outfile.close()

logging.info('Done!')
back to top