#!/usr/bin/python

# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
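
# Example invocation (illustrative; file names and parameter values below are
# hypothetical):
#
#   pycbc_fit_sngls_binned --trigger-file H1-TRIGGERS.hdf --bank-file BANK.hdf \
#       --ifo H1 --fit-function exponential --sngl-stat new_snr \
#       --stat-threshold 6.25 6.5 6.75 --bin-param template_duration \
#       --bin-spacing log --num-bins 8 --plot-dir ./plots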

from __future__ import division

import sys, h5py
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt

import copy, numpy as np

from scipy.stats import kstest

from pycbc import io, events, pnutils, bin_utils, results
from pycbc.events import trigger_fits as trstats
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####
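# Map each --sngl-stat choice to a function of (snr, reduced chisq, ...):
# most are ranking statistics defined in pycbc.events, while plain 'snr' and
# 'snronchi' are simple lambdas.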

stat_dict = {
    "new_snr"       : events.newsnr,
    "effective_snr" : events.effsnr,
    "snr"           : lambda snr, rchisq : snr,
    "snronchi"      : lambda snr, rchisq : snr / (rchisq ** 0.5),
    "newsnr_sgveto"      : events.newsnr_sgveto,
}
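# Evaluate the chosen ranking statistic.  Some statistics take an extra
# tunable factor (--stat-factor); newsnr_sgveto additionally uses the
# sine-Gaussian chisq values read from the trigger file.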

def get_stat(statchoice, snr, rchisq, sgchisq, fac):
    if fac is not None:
        if statchoice not in ['new_snr', 'effective_snr']:
            raise RuntimeError("Can't use --stat-factor with this statistic!")
        return stat_dict[statchoice](snr, rchisq, fac)
    elif statchoice == 'newsnr_sgveto':
        return stat_dict[statchoice](snr, rchisq, sgchisq)
    else:
        return stat_dict[statchoice](snr, rchisq)

#### MAIN ####

parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger"
                " distributions to various functions")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trigger-file",
                    help="Input hdf5 file containing single triggers. "
                    "Required")
parser.add_argument("--bank-file", default=None,
                    help="hdf file containing template parameters. Required")
parser.add_argument("--veto-file", nargs='*', default=[], action='append',
                    help="File(s) in .xml format with veto segments to apply "
                    "to triggers before fitting")
parser.add_argument("--veto-segment-name", nargs='*', default=[], action='append',
                    help="Name(s) of veto segments to apply. Optional, if not "
                    "given all segments for a given ifo will be used")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=["snr", "snronchi", "effective_snr", "new_snr", "newsnr_sgveto"],
                    help="Function of SNR and chisq to perform fits with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics. Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
parser.add_argument("--stat-threshold", nargs="+", type=float,
                    help="Only fit triggers with statistic value above this "
                    "threshold : can be a space-separated list, then a fit "
                    "will be done for each threshold.  Required.  Typical "
                    "values 6.25 6.5 6.75")
parser.add_argument("--prune-param",
                    help="Parameter to define bins for 'pruning' loud triggers"
                    " to make the fit insensitive to signals and outliers. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--prune-bins", type=int,
                    help="Number of bins to divide bank into when pruning")
parser.add_argument("--prune-number", type=int,
                    help="Number of loudest events to prune in each bin")
parser.add_argument("--log-prune-param", action='store_true',
                    help="Bin in the log of prune-param")
parser.add_argument("--f-lower", type=float, default=0.,
                    help="Starting frequency for calculating template "
                    "duration; if not given, duration will be read from "
                    "single trigger files")
# FIXME : allow choice of SEOBNRv2/v4 or PhenD duration formula ?
parser.add_argument("--min-duration", default=0.,
                    help="Fudge factor for templates with tiny or negative "
                    "values of template_duration: add to duration values "
                    "before fitting. Units seconds")
parser.add_argument("--bin-param", required=True,
                    help="Parameter over which to bin when fitting. Required. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--bin-spacing", choices=["linear", "log"],
                    help="How to space parameter bin edges")
parser.add_argument("--num-bins", type=int,
                    help="Number of regularly spaced bins to use over the "
                    " parameter")
parser.add_argument("--bin-param-units",
                    help="String to display units of the binning parameter")
outputchoice = parser.add_mutually_exclusive_group()
outputchoice.add_argument("--plot-dir",
                    help="Plot the fits made, the variation of fitting "
                    "coefficients and the Kolmogorov-Smirnov test values "
                    "and save to the specified directory.")
outputchoice.add_argument("--output-file",
                    help="Output a plot of hists and fits made for a single "
                    "threshold value.")
parser.add_argument("--user-tag", default="",
                    help="Put a possibly informative string in the names of "
                    "plot files")

args = parser.parse_args()

args.veto_segment_name = sum(args.veto_segment_name, [])
args.veto_file = sum(args.veto_file, [])

if len(args.veto_segment_name) != len(args.veto_file):
    raise RuntimeError("Number of veto files much match veto file names")

if (args.prune_param or args.prune_bins or args.prune_number) and not \
   (args.prune_param and args.prune_bins and args.prune_number):
    raise RuntimeError("To prune, need to specify param, number of bins and "
                       "nonzero number to prune in each bin!")

if args.output_file is not None and len(args.stat_threshold) > 1:
    raise RuntimeError("Cannot plot more than one threshold in a single "
                       "output file!")

if args.verbose:
    log_level = logging.DEBUG
else:
    log_level = logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

statname = "reweighted SNR" if args.sngl_stat == "new_snr" else \
           args.sngl_stat.replace("_", " ").replace("snr", "SNR")
paramname = args.bin_param.replace("_", " ")
paramtag = args.bin_param.replace("_", "")
if args.plot_dir:
    if not args.plot_dir.endswith('/'):
        args.plot_dir += '/'
    plotbase = args.plot_dir + args.ifo + "-" + args.user_tag

logging.info('Opening trigger file: %s' % args.trigger_file)
trigf = h5py.File(args.trigger_file, 'r')
logging.info('Opening template file: %s' % args.bank_file)
templatef = h5py.File(args.bank_file, 'r')

# get the stat values
minth = min(args.stat_threshold)
snr = trigf[args.ifo+'/snr'][:]
# SNR cut to reduce memory usage
logging.info('Applying SNR cut at %f' % minth)
aboveminth = snr >= minth
snr = snr[aboveminth]
chisq = trigf[args.ifo+'/chisq'][:][aboveminth]
chisq_dof = trigf[args.ifo+'/chisq_dof'][:][aboveminth]
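# chisq_dof holds the number of chisq bins p, so the reduced chisq used by
# the ranking statistics is chisq / (2p - 2)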
rchisq = chisq / (2 * chisq_dof - 2)

try:
    sgchisq = trigf[args.ifo+'/sg_chisq'][:][aboveminth]
except KeyError:
    sgchisq = None

del chisq
del chisq_dof
logging.info('Calculating stat values')
stat = get_stat(args.sngl_stat, snr, rchisq, sgchisq, args.stat_factor)
del snr
del rchisq
del sgchisq

# get the duration values if needed
if args.bin_param == 'template_duration' and not args.f_lower:
    logging.info('Using template duration from the trigger file')
    trig_dur = True
else:
    trig_dur = False

# stat threshold to reduce trigger numbers
minth = min(args.stat_threshold)
abovethresh = stat >= minth
stat = stat[abovethresh]
tid = trigf[args.ifo+'/template_id'][:][aboveminth][abovethresh]
time = trigf[args.ifo+'/end_time'][:][aboveminth][abovethresh]
if trig_dur:
    tdur = trigf[args.ifo+'/template_duration'][:][aboveminth][abovethresh]
logging.info('%i trigs left after thresholding at %f' % (len(stat), minth))

# now do vetoing
for veto_file, veto_segment_name in zip(args.veto_file, args.veto_segment_name):
    retain, junk = events.veto.indices_outside_segments(time, [veto_file],
                             ifo=args.ifo, segment_name=veto_segment_name)
    stat = stat[retain]
    tid = tid[retain]
    time = time[retain]
    if trig_dur:
        tdur = tdur[retain]
    logging.info('%i trigs left after vetoing with %s' %
                                                        (len(stat), veto_file))

### Functions for doing the pruning (removal of trigs at loudest times)
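# The pruning step repeatedly finds the loudest remaining trigger, locates its
# template in one of --prune-bins bins of --prune-param and, if that bin is not
# yet full, removes all triggers within +-prune_window seconds of it; it stops
# once --prune-number loudest events have been removed in every bin.  This is
# intended to keep a few loud events (signals or outliers) from biasing the
# fits, as described in the --prune-param help above.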

def get_masses(args, tid):
    m1 = templatef['mass1'][:][tid]
    m2 = templatef['mass2'][:][tid]
    s1z = templatef['spin1z'][:][tid]
    s2z = templatef['spin2z'][:][tid]
    return m1, m2, s1z, s2z

def get_param(args, tag, m1, m2, s1z, s2z):
    # here used for both pruning and binning params
    paramarg = getattr(args, tag+'_param')
    try:
        # len() will fail if m1 is a float rather than a sequence
        logging.info('Getting %s values for %i triggers' % (paramarg, len(m1)))
    except TypeError:
        pass
    if paramarg == 'mchirp':
        parvals, _ = pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
    elif paramarg == 'mtotal':
        parvals = m1 + m2
    elif paramarg == 'template_duration':
        # will default to SEOBNRv4 duration function
        parvals = pnutils.get_imr_duration(m1, m2, s1z, s2z, args.f_lower)
        parvals += args.min_duration
    elif paramarg in pnutils.named_frequency_cutoffs.keys():
        parvals = pnutils.frequency_cutoff_from_name(paramarg, m1, m2, s1z, s2z)
    else:
        # try asking for a LALSimulation frequency function
        parvals = pnutils.get_freq(paramarg, m1, m2, s1z, s2z)
    return parvals

# get min and max param values in bank
if args.prune_param:
    logging.info('Getting min and max param values')
    prpars = get_param(args, 'prune',
                       templatef['mass1'][:], templatef['mass2'][:],
                       templatef['spin1z'][:], templatef['spin2z'][:])
    minprpar = min(prpars)
    maxprpar = max(prpars)
    del prpars
    logging.info('prune param range %f %f' % (minprpar, maxprpar))

def which_bin(par, minpar, maxpar, nbins, log=False):
    """
    Returns the bin index in which the parameter value belongs (from 0 through
    nbins-1) when dividing the range between minpar and maxpar equally into
    nbins bins.
    """
    assert (par >= minpar and par <= maxpar)
    if log:
        par, minpar, maxpar = np.log(par), np.log(minpar), np.log(maxpar)
    # par lies some fraction of the way between min and max
    frac = float(par - minpar) / float(maxpar - minpar)
    # binind then lies between 0 and nbins - 1
    binind = int(frac * nbins)
    # except for a corner case
    if par == maxpar:
        binind = nbins - 1
    return binind

if args.prune_param:
    # hard-coded time window of 0.1s
    args.prune_window = 0.1
    # initialize bin storage
    prunedtimes = {}
    for i in range(args.prune_bins):
        prunedtimes[i] = []

    # keep a record of the triggers if all successive loudest events were to
    # be pruned
    statpruneall = copy.deepcopy(stat)
    tidpruneall = copy.deepcopy(tid)
    timepruneall = copy.deepcopy(time)

    # many trials may be required to prune in 'quieter' bins
    for j in range(1000):
        # are all the bins full already?
        numpruned = sum([len(prunedtimes[i]) for i in range(args.prune_bins)])
        if numpruned == args.prune_bins * args.prune_number:
            logging.info('Finished pruning!')
            break
        if numpruned > args.prune_bins * args.prune_number:
            logging.error('Uh-oh, we pruned too many things .. %i, to be '
                          'precise' % numpruned)
            raise RuntimeError
        loudest = np.argmax(statpruneall)
        lstat = statpruneall[loudest]
        ltid = tidpruneall[loudest]
        ltime = timepruneall[loudest]
        m1, m2, s1z, s2z = get_masses(args, ltid)
        lbin = which_bin(get_param(args, 'prune', m1, m2, s1z, s2z),
                         minprpar, maxprpar,
                         args.prune_bins, log=args.log_prune_param)
        # is the bin where the loudest trigger lives full already?
        if len(prunedtimes[lbin]) == args.prune_number:
            logging.info('%i - Bin %i full, not pruning event with stat %f at time '
                         '%.3f' % (j, lbin, lstat, ltime))
            # prune the reference trigger array
            retain = abs(timepruneall - ltime) > args.prune_window
            statpruneall = statpruneall[retain]
            tidpruneall = tidpruneall[retain]
            timepruneall = timepruneall[retain]
            del retain
            continue
        else:
            logging.info('Pruning event with stat %f at time %.3f in bin %i' %
                         (lstat, ltime, lbin))
            # now do the pruning
            retain = abs(time - ltime) > args.prune_window
            logging.info('%i trigs before pruning' % len(stat))
            stat = stat[retain]
            logging.info('%i trigs remain' % len(stat))
            tid = tid[retain]
            time = time[retain]
            if trig_dur:
                tdur = tdur[retain]
            # also for the reference trig arrays
            retain = abs(timepruneall - ltime) > args.prune_window
            statpruneall = statpruneall[retain]
            tidpruneall = tidpruneall[retain]
            timepruneall = timepruneall[retain]
            # record the time
            prunedtimes[lbin].append(ltime)
            del retain
    del statpruneall
    del tidpruneall
    del timepruneall

# get binning params after pruning
if trig_dur:
    binpars = tdur + args.min_duration
else:
    m1, m2, s1z, s2z = get_masses(args, tid)
    binpars = get_param(args, 'bin',
                        m1, m2, s1z, s2z)
logging.info("Parameter range of triggers: %f - %f" %
                                                  (min(binpars), max(binpars)))
# get the bins
# we assume that parvals are all positive
assert min(binpars) >= 0
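# pad the outer edges slightly so that the extreme parameter values fall
# strictly inside the first and last bins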
pmin = 0.999 * min(binpars)
pmax = 1.001 * max(binpars)
if args.bin_spacing == "linear":
    pbins = bin_utils.LinearBins(pmin, pmax, args.num_bins)
elif args.bin_spacing == "log":
    pbins = bin_utils.LogarithmicBins(pmin, pmax, args.num_bins)
# list of bin indices
binind = [pbins[c] for c in pbins.centres()]
logging.info("Assigning trigger param values to bins")
# FIXME: This is slow!! either find a better way of using pylal.rate
# or write faster binning routine
pind = np.array([pbins[par] for par in binpars])

logging.info("Getting max counts in bins")
# determine trigger counts first to get plot limits to make them the same
# for all thresholds; use only the smallest threshold requested
bincounts = []
for i in binind:
    vals_inbin = stat[pind == i]
    bincounts.append(sum(vals_inbin >= minth))
maxcount = max(bincounts)
plotrange = np.linspace(0.95 * min(stat), 1.05 * max(stat), 100)

# initialize result storage
parbins = {}
counts = {}
templates = {}
alphas = {}
stdev = {}
ks_prob = {}
nabove = {}

histcolors = ['r',(1.0,0.6,0),'y','g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]
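
# Main fitting loop: for each statistic threshold, count the triggers in each
# parameter bin, fit the chosen functional form to the statistic values above
# threshold by maximum likelihood, perform a KS goodness-of-fit test, and
# overlay reverse cumulative histograms with the fitted curves.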

for th in args.stat_threshold:
    logging.info("Fitting above threshold %f" % th)
    counts[th] = {}
    alphas[th] = {}
    stdev[th] = {}
    ks_prob[th] = {}

    if args.output_file:
        fig = plt.figure()
    for i, lower, upper in zip(binind, pbins.lower(), pbins.upper()):
        # determine number of templates generating the triggers involved
        # for hdf5, use the template id; otherwise use masses
        tid_inbin = tid[pind == i]
        numtmpl = len(set(tid_inbin))
        templates[i] = numtmpl
        vals_inbin = stat[pind == i]
        counts[th][i] = sum(vals_inbin >= th)
        if len(vals_inbin) == 0:
            logging.info("No trigs in bin %f-%f", (lower, upper))
            continue
        # do the fit
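        # fit_above_thresh returns the maximum-likelihood estimate of the fit
        # parameter alpha for the chosen functional form, together with its
        # standard error, using only the values above threshold th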
        alpha, sig_alpha = trstats.fit_above_thresh(
                                             args.fit_function, vals_inbin, th)
        alphas[th][i] = alpha
        stdev[th][i] = sig_alpha
        _, ks_prob[th][i] = trstats.KS_test(
                                      args.fit_function, vals_inbin, alpha, th)
        # add histogram to plot
        histcounts, edges = np.histogram(vals_inbin, bins=50)
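        # reverse cumulative sum: number of triggers at or above each lower
        # bin edge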
        cum_counts = histcounts[::-1].cumsum()[::-1]
        binlabel = r"%.3g - %.3g" % (lower, upper)
        # histogram of fitted values
        plt.semilogy(edges[:-1], cum_counts, linewidth=2,
                     color=histcolors[i], label=binlabel, alpha=0.6)
        # fit central value
        plt.semilogy(plotrange, counts[th][i] * \
                     trstats.cum_fit(args.fit_function, plotrange, alpha, th),
                     "--", color=histcolors[i],
                     label=r"$\alpha = $%.2f $\pm$ %.2f" % (alpha, sig_alpha))
        # 1sigma upper deviation on alpha
        plt.semilogy(plotrange, counts[th][i] * \
                     trstats.cum_fit(args.fit_function, plotrange, alpha + \
                     sig_alpha, th), ":", alpha=0.6, color=histcolors[i])
        # 1sigma lower deviation
        plt.semilogy(plotrange, counts[th][i] * \
                     trstats.cum_fit(args.fit_function, plotrange, alpha - \
                     sig_alpha, th), ":", alpha=0.6, color=histcolors[i])
    # finish the hist plot
    leg = plt.legend(labelspacing=0.2)
    unitstring = " (%s)" % args.bin_param_units if \
                                       args.bin_param_units is not None else ""
    leg.set_title(paramname+unitstring)
    plt.setp(leg.get_texts(), fontsize=11)
    plt.ylim(0.7, 2*maxcount)
    plt.xlim(0.9*minth, 1.1*max(plotrange))
    plt.grid()
    plt.xlabel(statname, size="large")
    plt.ylabel("cumulative number", size="large")
    if args.plot_dir:
        plt.title(args.ifo + " " + statname + " distribution split by " + \
              paramname)
        dest = plotbase + "_" + args.sngl_stat + "_cdf_by_" + paramtag[0:3] + \
                                             "_fit_thresh_" + str(th) + ".png"
        logging.info("Saving cumhist to %s" % dest)
        plt.savefig(dest)
    elif args.output_file:
        logging.info("Saving cumhist to %s" % args.output_file)
        results.save_fig_with_metadata(
            fig, args.output_file,
            title="%s: %s histogram of single detector triggers" % (args.ifo,
                  statname),
            caption=(r"Histogram of single detector %s values binned by %s "
                     "with fitted %s distribution parameterized by &alpha;"\
                     % (statname, paramname, args.fit_function)),
            cmd=" ".join(sys.argv)
        )
    plt.close()

# don't make any more plots if only making the single rainbow hist
if args.output_file:
    exit()

# make plots of alpha, trig count and KS significance
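# alpha parameterizes how fast the fitted distribution falls off above
# threshold (for the exponential form it is the decay rate), so larger alpha
# indicates a steeper background tail.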
for th in args.stat_threshold:
    plt.errorbar(pbins.centres(), [alphas[th][i] for i in binind],
                 yerr=[stdev[th][i] for i in binind], fmt="+-",
                 label=args.ifo + " fit above %.2f" % th)
if args.bin_spacing == "log": plt.semilogx()
plt.xlim(0.03, 150)
plt.ylim(0, 10)
plt.grid()
plt.legend(loc="best")
plt.xlabel(paramname, size="large")
plt.ylabel(r"fit parameter $\alpha$", size='large')
plt.savefig(plotbase + '_alpha_vs_' + paramtag[0:3] + '.png')
plt.close()

for th in args.stat_threshold:
    plt.errorbar(pbins.centres(),
                 [float(counts[th][i])/templates[i] for i in binind],
                 yerr=[counts[th][i]**0.5/templates[i] for i in binind],
                 fmt="+-", label=args.ifo + " trigs above %.2f" % th)
if args.bin_spacing == "log": plt.semilogx()
plt.xlim(0.03, 150)
plt.grid()
plt.legend(loc="best")
plt.xlabel(paramname, size="large")
plt.ylabel(r"Triggers above threshold per template", size='large')
plt.savefig(plotbase + '_nabove_vs_' + paramtag[0:3] + '.png')
plt.close()

for th in args.stat_threshold:
    plt.plot(pbins.centres(), [ks_prob[th][i] for i in binind],
          '+--', label=args.ifo+' KS prob, thresh %.2f' % th)
if args.bin_spacing == 'log':
    plt.loglog()
else:
    plt.semilogy()
plt.xlim(0.03, 150)
plt.grid()
leg = plt.legend(loc='best', labelspacing=0.2)
plt.setp(leg.get_texts(), fontsize=11)
plt.xlabel(paramname, size='large')
plt.ylabel('KS test p-value')
plt.savefig(plotbase + '_KS_prob_vs_' + paramtag[0:3] + '.png')
plt.close()

logging.info('Done!')