#!/usr/bin/python
# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
from __future__ import division
import sys, h5py
import argparse, logging
from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import copy, numpy as np
from scipy.stats import kstest
from pycbc import io, events, pnutils, bin_utils, results
from pycbc.events import trigger_fits as trstats
import pycbc.version
#### DEFINITIONS AND FUNCTIONS ####
def _stat_plain_snr(snr, rchisq):
    """Return snr unchanged; rchisq is accepted only for a uniform signature."""
    return snr


def _stat_snr_on_root_chisq(snr, rchisq):
    """Return snr divided by the square root of rchisq."""
    return snr / (rchisq ** 0.5)


# Lookup table from the name of a single-detector statistic to the callable
# that evaluates it. Every callable takes (snr, rchisq) as its first two
# positional arguments; get_stat() passes an extra third argument to some.
stat_dict = dict(
    new_snr=events.newsnr,
    effective_snr=events.effsnr,
    snr=_stat_plain_snr,
    snronchi=_stat_snr_on_root_chisq,
    newsnr_sgveto=events.newsnr_sgveto,
)
def get_stat(statchoice, snr, rchisq, sgchisq, fac):
    """
    Evaluate the single-detector statistic named by statchoice.

    Parameters
    ----------
    statchoice : str
        Key into stat_dict selecting the statistic.
    snr, rchisq :
        Passed as the first two positional arguments of the statistic.
    sgchisq :
        Passed as third argument when statchoice is 'newsnr_sgveto'; may be
        None for every other statistic.
    fac : float or None
        Optional tuning factor, passed as third argument. Only allowed for
        'new_snr' and 'effective_snr'.

    Returns
    -------
    Whatever the selected stat_dict entry returns.

    Raises
    ------
    RuntimeError
        If fac is given with a statistic that does not accept it, or if
        'newsnr_sgveto' is requested but sgchisq is None.
    """
    if fac is not None:
        if statchoice not in ['new_snr', 'effective_snr']:
            raise RuntimeError("Can't use --stat-factor with this statistic!")
        return stat_dict[statchoice](snr, rchisq, fac)
    if statchoice == 'newsnr_sgveto':
        # Fail with a clear message instead of handing None to the
        # statistic function and crashing somewhere inside it.
        if sgchisq is None:
            raise RuntimeError("The newsnr_sgveto statistic needs sg_chisq "
                               "values but none are available!")
        return stat_dict[statchoice](snr, rchisq, sgchisq)
    return stat_dict[statchoice](snr, rchisq)
#### MAIN ####
# Command-line interface. Options whose help text says "Required" and
# without which the script cannot open its inputs or choose a threshold are
# enforced by argparse, so that a missing option is reported immediately
# rather than as an obscure failure later on.
parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger"
                " distributions to various functions")
parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                    "Required")
parser.add_argument("--bank-file", default=None, required=True,
                    help="hdf file containing template parameters. Required")
# 'append' combined with nargs='*' gives a list of lists; it is flattened
# after parsing.
parser.add_argument("--veto-file", nargs='*', default=[], action='append',
                    help="File(s) in .xml format with veto segments to apply "
                    "to triggers before fitting")
parser.add_argument("--veto-segment-name", nargs='*', default=[], action='append',
                    help="Name(s) of veto segments to apply. Optional, if not "
                    "given all segments for a given ifo will be used")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=["snr", "snronchi", "effective_snr", "new_snr", "newsnr_sgveto"],
                    help="Function of SNR and chisq to perform fits with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics. Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
parser.add_argument("--stat-threshold", nargs="+", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                    "threshold : can be a space-separated list, then a fit "
                    "will be done for each threshold. Required. Typical "
                    "values 6.25 6.5 6.75")
parser.add_argument("--prune-param",
                    help="Parameter to define bins for 'pruning' loud triggers"
                    " to make the fit insensitive to signals and outliers. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--prune-bins", type=int,
                    help="Number of bins to divide bank into when pruning")
parser.add_argument("--prune-number", type=int,
                    help="Number of loudest events to prune in each bin")
parser.add_argument("--log-prune-param", action='store_true',
                    help="Bin in the log of prune-param")
parser.add_argument("--f-lower", type=float, default=0.,
                    help="Starting frequency for calculating template "
                    "duration; if not given, duration will be read from "
                    "single trigger files")
# FIXME : allow choice of SEOBNRv2/v4 or PhenD duration formula ?
# type=float is needed here: without it a value given on the command line
# stays a string and cannot be added to the duration values.
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                    "values of template_duration: add to duration values "
                    "before fitting. Units seconds")
parser.add_argument("--bin-param", required=True,
                    help="Parameter over which to bin when fitting. Required. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--bin-spacing", choices=["linear", "log"],
                    help="How to space parameter bin edges")
parser.add_argument("--num-bins", type=int,
                    help="Number of regularly spaced bins to use over the "
                    " parameter")
parser.add_argument("--bin-param-units",
                    help="String to display units of the binning parameter")
# Either a directory of plots or a single output plot, never both.
outputchoice = parser.add_mutually_exclusive_group()
outputchoice.add_argument("--plot-dir",
                    help="Plot the fits made, the variation of fitting "
                    "coefficients and the Kolmogorov-Smirnov test values "
                    "and save to the specified directory.")
outputchoice.add_argument("--output-file",
                    help="Output a plot of hists and fits made for a single "
                    "threshold value.")
parser.add_argument("--user-tag", default="",
                    help="Put a possibly informative string in the names of "
                    "plot files")
args = parser.parse_args()
# The veto options use action='append' with nargs='*', so each is a list of
# lists: flatten them.
args.veto_segment_name = sum(args.veto_segment_name, [])
args.veto_file = sum(args.veto_file, [])

# Veto files and segment names are paired up one-to-one later on.
if len(args.veto_segment_name) != len(args.veto_file):
    raise RuntimeError("Number of veto files must match number of veto "
                       "segment names")

# Pruning needs all three of its options, or none of them.
if (args.prune_param or args.prune_bins or args.prune_number) and not \
        (args.prune_param and args.prune_bins and args.prune_number):
    raise RuntimeError("To prune, need to specify param, number of bins and "
                       "nonzero number to prune in each bin!")

if args.output_file is not None and len(args.stat_threshold) > 1:
    raise RuntimeError("Cannot plot more than one threshold in a single "
                       "output file!")

# Without any output destination the script would do all of the fitting and
# then fail when trying to save the summary plots, so stop right away.
if not (args.plot_dir or args.output_file):
    raise RuntimeError("Need to specify one of --plot-dir or --output-file!")

if args.verbose:
    log_level = logging.DEBUG
else:
    log_level = logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Human-readable names for axis labels and titles, and a compact tag for
# use in plot file names.
statname = "reweighted SNR" if args.sngl_stat == "new_snr" else \
    args.sngl_stat.replace("_", " ").replace("snr", "SNR")
paramname = args.bin_param.replace("_", " ")
paramtag = args.bin_param.replace("_", "")

if args.plot_dir:
    if not args.plot_dir.endswith('/'):
        args.plot_dir += '/'
    # Common prefix of every plot file written to the plot directory.
    plotbase = args.plot_dir + args.ifo + "-" + args.user_tag
logging.info('Opening trigger file: %s' % args.trigger_file)
trigf = h5py.File(args.trigger_file, 'r')
logging.info('Opening template file: %s' % args.bank_file)
# templatef stays open: get_masses() and the pruning setup read from it.
templatef = h5py.File(args.bank_file, 'r')

# get the stat values
minth = min(args.stat_threshold)
snr = trigf[args.ifo+'/snr'][:]
# SNR cut to reduce memory usage
logging.info('Applying SNR cut at %f' % minth)
aboveminth = snr >= minth
snr = snr[aboveminth]
chisq = trigf[args.ifo+'/chisq'][:][aboveminth]
chisq_dof = trigf[args.ifo+'/chisq_dof'][:][aboveminth]
rchisq = chisq / (2 * chisq_dof - 2)
# sg_chisq is optional in the trigger file
try:
    sgchisq = trigf[args.ifo+'/sg_chisq'][:][aboveminth]
except KeyError:
    sgchisq = None
del chisq
del chisq_dof

logging.info('Calculating stat values')
stat = get_stat(args.sngl_stat, snr, rchisq, sgchisq, args.stat_factor)
# the inputs of the statistic are not needed any more
del snr
del rchisq
del sgchisq

# get the duration values if needed: they are read from the trigger file
# only when binning by template duration with no f-lower given
if args.bin_param == 'template_duration' and not args.f_lower:
    logging.info('Using template duration from the trigger file')
    trig_dur = True
else:
    trig_dur = False

# stat threshold to reduce trigger numbers
abovethresh = stat >= minth
stat = stat[abovethresh]
# the other columns need both cuts applied, in the same order
tid = trigf[args.ifo+'/template_id'][:][aboveminth][abovethresh]
time = trigf[args.ifo+'/end_time'][:][aboveminth][abovethresh]
if trig_dur:
    tdur = trigf[args.ifo+'/template_duration'][:][aboveminth][abovethresh]
logging.info('%i trigs left after thresholding at %f' % (len(stat), minth))

# now do vetoing, one (file, segment name) pair at a time
for veto_file, veto_segment_name in zip(args.veto_file, args.veto_segment_name):
    # only the first returned value (indices to keep) is used
    retain, _ = events.veto.indices_outside_segments(time, [veto_file],
                             ifo=args.ifo, segment_name=veto_segment_name)
    stat = stat[retain]
    tid = tid[retain]
    time = time[retain]
    if trig_dur:
        tdur = tdur[retain]
    # report the file just applied rather than the whole list of files
    logging.info('%i trigs left after vetoing with %s' %
                 (len(stat), veto_file))
### Functions for doing the pruning (removal of trigs at loudest times)
def get_masses(args, tid):
    """
    Return (mass1, mass2, spin1z, spin2z) read from the template bank file
    and indexed by tid. The args parameter is not used.
    """
    columns = ('mass1', 'mass2', 'spin1z', 'spin2z')
    # read each full dataset into memory, then pick out the wanted entries
    m1, m2, s1z, s2z = [templatef[name][:][tid] for name in columns]
    return m1, m2, s1z, s2z
def get_param(args, tag, m1, m2, s1z, s2z):
    """
    Compute the parameter chosen on the command line from template masses
    and spins. Used for both the pruning and the binning parameter.

    Parameters
    ----------
    args :
        Parsed options; the parameter name is read from the attribute
        '<tag>_param', and f_lower / min_duration are read when the
        parameter is 'template_duration'.
    tag : str
        'prune' or 'bin', selecting which option holds the parameter name.
    m1, m2, s1z, s2z :
        Template masses and spins, either scalars or sequences.

    Returns
    -------
    The parameter value(s) for the given templates.
    """
    paramarg = getattr(args, tag+'_param')
    # len() raises TypeError when m1 is a scalar rather than a sequence; in
    # that case there is simply nothing to count in the log message
    try:
        ntrigs = len(m1)
    except TypeError:
        pass
    else:
        logging.info('Getting %s values for %i triggers' % (paramarg, ntrigs))
    if paramarg == 'mchirp':
        parvals, _ = pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
    elif paramarg == 'mtotal':
        parvals = m1 + m2
    elif paramarg == 'template_duration':
        # will default to SEOBNRv4 duration function
        parvals = pnutils.get_imr_duration(m1, m2, s1z, s2z, args.f_lower)
        # float() guards against min_duration arriving as a string
        parvals = parvals + float(args.min_duration)
    elif paramarg in pnutils.named_frequency_cutoffs.keys():
        parvals = pnutils.frequency_cutoff_from_name(paramarg, m1, m2, s1z, s2z)
    else:
        # try asking for a LALSimulation frequency function
        parvals = pnutils.get_freq(paramarg, m1, m2, s1z, s2z)
    return parvals
# Find the extent of the pruning parameter over the whole template bank;
# which_bin() needs these limits to place a template in a pruning bin.
if args.prune_param:
    logging.info('Getting min and max param values')
    bank_columns = [templatef[key][:]
                    for key in ('mass1', 'mass2', 'spin1z', 'spin2z')]
    prpars = get_param(args, 'prune', *bank_columns)
    minprpar, maxprpar = min(prpars), max(prpars)
    # only the limits are kept
    del bank_columns
    del prpars
    logging.info('prune param range %f %f' % (minprpar, maxprpar))
def which_bin(par, minpar, maxpar, nbins, log=False):
    """
    Returns the bin index in which the parameter value belongs (from 0 through
    nbins-1) when dividing the range between minpar and maxpar equally into
    nbins bins.

    Parameters
    ----------
    par : float
        Parameter value to be binned (a scalar).
    minpar, maxpar : float
        Lower and upper ends of the binned range.
    nbins : int
        Number of bins.
    log : bool
        If True, the range is divided equally in the log of the parameter.

    Returns
    -------
    int
        Bin index between 0 and nbins - 1.

    Raises
    ------
    ValueError
        If par lies outside the range from minpar to maxpar.
    """
    # explicit check rather than assert, which is skipped under python -O
    if not minpar <= par <= maxpar:
        raise ValueError("Parameter value %s is outside the range %s - %s"
                         % (par, minpar, maxpar))
    if log:
        par, minpar, maxpar = np.log(par), np.log(minpar), np.log(maxpar)
    # a zero-width range cannot be divided: everything is in the first bin
    if maxpar == minpar:
        return 0
    # par lies some fraction of the way between min and max
    frac = float(par - minpar) / float(maxpar - minpar)
    # frac * nbins lies between 0 and nbins, reaching nbins when par equals
    # maxpar: clamp so that the top edge belongs to the last bin
    return min(int(frac * nbins), nbins - 1)
# Pruning: repeatedly take the loudest remaining trigger and, unless the
# pruning bin of its template is already full, remove all triggers within a
# short time window around it from the arrays that will be fitted.
if args.prune_param:
    # hard-coded time window of 0.1s
    args.prune_window = 0.1
    # initialize bin storage: times of the pruned events in each bin
    prunedtimes = {}
    for i in range(args.prune_bins):
        prunedtimes[i] = []
    # keep a record of the triggers if all successive loudest events were to
    # be pruned: this reference copy is what the loudest event is taken from
    statpruneall = copy.deepcopy(stat)
    tidpruneall = copy.deepcopy(tid)
    timepruneall = copy.deepcopy(time)
    # many trials may be required to prune in 'quieter' bins
    for j in range(1000):
        # are all the bins full already?
        numpruned = sum(len(prunedtimes[i]) for i in range(args.prune_bins))
        if numpruned == args.prune_bins * args.prune_number:
            logging.info('Finished pruning!')
            break
        if numpruned > args.prune_bins * args.prune_number:
            logging.error('Uh-oh, we pruned too many things .. %i, to be '
                          'precise' % numpruned)
            raise RuntimeError('Pruned more events than requested')
        # np.argmax fails on an empty array, so stop once there is nothing
        # left to choose from
        if len(statpruneall) == 0:
            logging.warning('Ran out of triggers after pruning %i events',
                            numpruned)
            break
        loudest = np.argmax(statpruneall)
        lstat = statpruneall[loudest]
        ltid = tidpruneall[loudest]
        ltime = timepruneall[loudest]
        m1, m2, s1z, s2z = get_masses(args, ltid)
        lbin = which_bin(get_param(args, 'prune', m1, m2, s1z, s2z),
                         minprpar, maxprpar,
                         args.prune_bins, log=args.log_prune_param)
        # is the bin where the loudest trigger lives full already?
        if len(prunedtimes[lbin]) == args.prune_number:
            logging.info('%i - Bin %i full, not pruning event with stat %f at time '
                         '%.3f' % (j, lbin, lstat, ltime))
            # prune the reference trigger array only, so that the next trial
            # moves on to a different loudest event
            retain = abs(timepruneall - ltime) > args.prune_window
            statpruneall = statpruneall[retain]
            tidpruneall = tidpruneall[retain]
            timepruneall = timepruneall[retain]
            del retain
            continue
        logging.info('Pruning event with stat %f at time %.3f in bin %i' %
                     (lstat, ltime, lbin))
        # now do the pruning
        retain = abs(time - ltime) > args.prune_window
        logging.info('%i trigs before pruning' % len(stat))
        stat = stat[retain]
        logging.info('%i trigs remain' % len(stat))
        tid = tid[retain]
        time = time[retain]
        if trig_dur:
            tdur = tdur[retain]
        # also for the reference trig arrays
        retain = abs(timepruneall - ltime) > args.prune_window
        statpruneall = statpruneall[retain]
        tidpruneall = tidpruneall[retain]
        timepruneall = timepruneall[retain]
        # record the time
        prunedtimes[lbin].append(ltime)
        del retain
    else:
        # all trials were used: say so if the bins did not get filled
        numpruned = sum(len(prunedtimes[i]) for i in range(args.prune_bins))
        if numpruned < args.prune_bins * args.prune_number:
            logging.warning('Stopped pruning after 1000 trials with only %i '
                            'events pruned', numpruned)
    del statpruneall
    del tidpruneall
    del timepruneall
# Values of the binning parameter for the triggers left after pruning
if trig_dur:
    binpars = tdur + args.min_duration
else:
    m1, m2, s1z, s2z = get_masses(args, tid)
    binpars = get_param(args, 'bin', m1, m2, s1z, s2z)
lowest, highest = min(binpars), max(binpars)
logging.info("Parameter range of triggers: %f - %f" % (lowest, highest))

# Build the bins. The parameter values are assumed to be non-negative, and
# the range is widened slightly so the extreme values fall inside it.
assert lowest >= 0
pmin = 0.999 * lowest
pmax = 1.001 * highest
if args.bin_spacing == "log":
    pbins = bin_utils.LogarithmicBins(pmin, pmax, args.num_bins)
elif args.bin_spacing == "linear":
    pbins = bin_utils.LinearBins(pmin, pmax, args.num_bins)

# index of every bin, found by looking up the bin centres
binind = [pbins[centre] for centre in pbins.centres()]

logging.info("Assigning trigger param values to bins")
# FIXME: This is slow!! either find a better way of using pylal.rate
# or write faster binning routine
pind = np.array([pbins[value] for value in binpars])

logging.info("Getting max counts in bins")
# The largest count of triggers above the smallest requested threshold in
# any bin fixes the vertical plot limit, so it is the same for all
# thresholds.
bincounts = [sum(stat[pind == idx] >= minth) for idx in binind]
maxcount = max(bincounts)
# statistic values at which the fitted curves are drawn
plotrange = np.linspace(0.95 * min(stat), 1.05 * max(stat), 100)

# initialize result storage
parbins, counts, templates = {}, {}, {}
alphas, stdev, ks_prob = {}, {}, {}
nabove = {}
# One colour per parameter bin; reused cyclically when there are more bins
# than colours.
histcolors = ['r',(1.0,0.6,0),'y','g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]

# For each threshold: fit the triggers in every bin, store the results and
# plot the cumulative histograms together with the fitted curves.
for th in args.stat_threshold:
    logging.info("Fitting above threshold %f" % th)
    counts[th] = {}
    alphas[th] = {}
    stdev[th] = {}
    ks_prob[th] = {}
    if args.output_file:
        # the figure object is needed for saving with metadata below
        fig = plt.figure()
    for i, lower, upper in zip(binind, pbins.lower(), pbins.upper()):
        inbin = pind == i
        # determine number of templates generating the triggers involved
        tid_inbin = tid[inbin]
        templates[i] = len(set(tid_inbin))
        vals_inbin = stat[inbin]
        counts[th][i] = sum(vals_inbin >= th)
        if len(vals_inbin) == 0:
            # the values must be separate arguments: a single tuple would
            # not match the two placeholders
            logging.info("No trigs in bin %f-%f", lower, upper)
            continue
        # do the fit
        alpha, sig_alpha = trstats.fit_above_thresh(
            args.fit_function, vals_inbin, th)
        alphas[th][i] = alpha
        stdev[th][i] = sig_alpha
        _, ks_prob[th][i] = trstats.KS_test(
            args.fit_function, vals_inbin, alpha, th)
        # add histogram to plot: number of triggers at or above each edge
        histcounts, edges = np.histogram(vals_inbin, bins=50)
        cum_counts = histcounts[::-1].cumsum()[::-1]
        binlabel = r"%.3g - %.3g" % (lower, upper)
        bincolor = histcolors[i % len(histcolors)]
        # histogram of fitted values
        plt.semilogy(edges[:-1], cum_counts, linewidth=2,
                     color=bincolor, label=binlabel, alpha=0.6)
        # fit central value
        plt.semilogy(plotrange, counts[th][i] * \
                     trstats.cum_fit(args.fit_function, plotrange, alpha, th),
                     "--", color=bincolor,
                     label=r"$\alpha = $%.2f $\pm$ %.2f" % (alpha, sig_alpha))
        # 1sigma upper deviation on alpha
        plt.semilogy(plotrange, counts[th][i] * \
                     trstats.cum_fit(args.fit_function, plotrange, alpha + \
                     sig_alpha, th), ":", alpha=0.6, color=bincolor)
        # 1sigma lower deviation
        plt.semilogy(plotrange, counts[th][i] * \
                     trstats.cum_fit(args.fit_function, plotrange, alpha - \
                     sig_alpha, th), ":", alpha=0.6, color=bincolor)
    # finish the hist plot
    leg = plt.legend(labelspacing=0.2)
    unitstring = " (%s)" % args.bin_param_units if \
        args.bin_param_units is not None else ""
    leg.set_title(paramname+unitstring)
    plt.setp(leg.get_texts(), fontsize=11)
    plt.ylim(0.7, 2*maxcount)
    plt.xlim(0.9*minth, 1.1*max(plotrange))
    plt.grid()
    plt.xlabel(statname, size="large")
    plt.ylabel("cumulative number", size="large")
    if args.plot_dir:
        plt.title(args.ifo + " " + statname + " distribution split by " + \
                  paramname)
        # NOTE(review): the leading space in " _fit_thresh_" ends up in the
        # file name; left unchanged in case anything relies on it - confirm
        dest = plotbase + "_" + args.sngl_stat + "_cdf_by_" + paramtag[0:3] + \
            " _fit_thresh_" + str(th) + ".png"
        logging.info("Saving cumhist to %s" % dest)
        plt.savefig(dest)
    elif args.output_file:
        logging.info("Saving cumhist to %s" % args.output_file)
        results.save_fig_with_metadata(
            fig, args.output_file,
            title="%s: %s histogram of single detector triggers" % (args.ifo,
                                                                    statname),
            caption=(r"Histogram of single detector %s values binned by %s "
                     "with fitted %s distribution parameterized by α"\
                     % (statname, paramname, args.fit_function)),
            cmd=" ".join(sys.argv)
        )
    plt.close()

# don't make any more plots if only making the single rainbow hist
if args.output_file:
    # sys.exit rather than the exit() helper, which only exists when the
    # site module has been loaded
    sys.exit()
# Summary plots against the binning parameter, with one curve per
# threshold. Bins that held no triggers have no fit results and no
# templates; they are drawn as NaN (a gap) instead of stopping the script.
def _per_template(value, ntemplates):
    """Return value / ntemplates, or NaN when the bin has no templates."""
    if not ntemplates:
        return np.nan
    return float(value) / ntemplates

# make plots of alpha, trig count and KS significance
for th in args.stat_threshold:
    plt.errorbar(pbins.centres(),
                 [alphas[th].get(i, np.nan) for i in binind],
                 yerr=[stdev[th].get(i, np.nan) for i in binind], fmt="+-",
                 label=args.ifo + " fit above %.2f" % th)
if args.bin_spacing == "log": plt.semilogx()
plt.xlim(0.03, 150)
plt.ylim(0, 10)
plt.grid()
plt.legend(loc="best")
plt.xlabel(paramname, size="large")
plt.ylabel(r"fit parameter $\alpha$", size='large')
plt.savefig(plotbase + '_alpha_vs_' + paramtag[0:3] + '.png')
plt.close()

# number of triggers above threshold per template; the error bar is the
# square root of the count, divided by the number of templates
for th in args.stat_threshold:
    plt.errorbar(pbins.centres(),
                 [_per_template(counts[th][i], templates[i]) for i in binind],
                 yerr=[_per_template(counts[th][i]**0.5, templates[i])
                       for i in binind],
                 fmt="+-", label=args.ifo + " trigs above %.2f" % th)
if args.bin_spacing == "log": plt.semilogx()
plt.xlim(0.03, 150)
plt.grid()
plt.legend(loc="best")
plt.xlabel(paramname, size="large")
plt.ylabel(r"Triggers above threshold per template", size='large')
plt.savefig(plotbase + '_nabove_vs_' + paramtag[0:3] + '.png')
plt.close()

# p-value of the KS test of the fit in each bin
for th in args.stat_threshold:
    plt.plot(pbins.centres(), [ks_prob[th].get(i, np.nan) for i in binind],
             '+--', label=args.ifo+' KS prob, thresh %.2f' % th)
if args.bin_spacing == 'log':
    plt.loglog()
else:
    plt.semilogy()
plt.xlim(0.03, 150)
plt.grid()
leg = plt.legend(loc='best', labelspacing=0.2)
plt.setp(leg.get_texts(), fontsize=11)
plt.xlabel(paramname, size='large')
plt.ylabel('KS test p-value')
plt.savefig(plotbase + '_KS_prob_vs_' + paramtag[0:3] + '.png')
plt.close()
logging.info('Done!')