#!/usr/bin/python

# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt

import copy, numpy as np
from scipy.stats import kstest

from pycbc import io, events, pnutils, bin_utils, results
from pycbc.events import trigger_fits as trstats
import pycbc.version

#### DEFINITIONS AND FUNCTIONS ####

stat_dict = {
    "new_snr"       : events.newsnr,
    "effective_snr" : events.effsnr,
    "snr"           : lambda snr, rchisq : snr,
    "snronchi"      : lambda snr, rchisq : snr / (rchisq ** 0.5),
    "newsnr_sgveto" : events.newsnr_sgveto,
}

def get_stat(statchoice, snr, rchisq, sgchisq, fac):
    if fac is not None:
        if statchoice not in ['new_snr', 'effective_snr']:
            raise RuntimeError("Can't use --stat-factor with this statistic!")
        return stat_dict[statchoice](snr, rchisq, fac)
    elif statchoice == 'newsnr_sgveto':
        return stat_dict[statchoice](snr, rchisq, sgchisq)
    else:
        return stat_dict[statchoice](snr, rchisq)

#### MAIN ####

parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger"
                " distributions to various functions")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true", default=False,
                    help="Print extra debugging information")
parser.add_argument("--trigger-file",
                    help="Input hdf5 file containing single triggers. "
                         "Required")
parser.add_argument("--bank-file", default=None,
                    help="hdf file containing template parameters. Required")
parser.add_argument("--veto-file", nargs='*', default=[], action='append',
                    help="File(s) in .xml format with veto segments to apply "
                         "to triggers before fitting")
parser.add_argument("--veto-segment-name", nargs='*', default=[],
                    action='append',
                    help="Name(s) of veto segments to apply. Optional, if not "
                         "given all segments for a given ifo will be used")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=["snr", "snronchi", "effective_snr", "new_snr",
                             "newsnr_sgveto"],
                    help="Function of SNR and chisq to perform fits with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                         "statistics. Values commonly used: 6 for new_snr, "
                         "250 or 50 for effective_snr")
parser.add_argument("--stat-threshold", nargs="+", type=float,
                    help="Only fit triggers with statistic value above this "
                         "threshold: can be a space-separated list, then a "
                         "fit will be done for each threshold. Required. "
                         "Typical values 6.25 6.5 6.75")
parser.add_argument("--prune-param",
                    help="Parameter to define bins for 'pruning' loud "
                         "triggers to make the fit insensitive to signals "
                         "and outliers. Choose from mchirp, mtotal, "
                         "template_duration or a named frequency cutoff in "
                         "pnutils or a frequency function in LALSimulation")
" "Choose from mchirp, mtotal, template_duration or a named " "frequency cutoff in pnutils or a frequency function in " "LALSimulation") parser.add_argument("--prune-bins", type=int, help="Number of bins to divide bank into when pruning") parser.add_argument("--prune-number", type=int, help="Number of loudest events to prune in each bin") parser.add_argument("--log-prune-param", action='store_true', help="Bin in the log of prune-param") parser.add_argument("--f-lower", type=float, default=0., help="Starting frequency for calculating template " "duration; if not given, duration will be read from " "single trigger files") # FIXME : allow choice of SEOBNRv2/v4 or PhenD duration formula ? parser.add_argument("--min-duration", default=0., help="Fudge factor for templates with tiny or negative " "values of template_duration: add to duration values " "before fitting. Units seconds") parser.add_argument("--bin-param", required=True, help="Parameter over which to bin when fitting. Required. " "Choose from mchirp, mtotal, template_duration or a named " "frequency cutoff in pnutils or a frequency function in " "LALSimulation") parser.add_argument("--bin-spacing", choices=["linear", "log"], help="How to space parameter bin edges") parser.add_argument("--num-bins", type=int, help="Number of regularly spaced bins to use over the " " parameter") parser.add_argument("--bin-param-units", help="String to display units of the binning parameter") outputchoice = parser.add_mutually_exclusive_group() outputchoice.add_argument("--plot-dir", help="Plot the fits made, the variation of fitting " "coefficients and the Kolmogorov-Smirnov test values " "and save to the specified directory.") outputchoice.add_argument("--output-file", help="Output a plot of hists and fits made for a single " "threshold value.") parser.add_argument("--user-tag", default="", help="Put a possibly informative string in the names of " "plot files") args = parser.parse_args() args.veto_segment_name = sum(args.veto_segment_name, []) args.veto_file = sum(args.veto_file, []) if len(args.veto_segment_name) != len(args.veto_file): raise RuntimeError("Number of veto files much match veto file names") if (args.prune_param or args.prune_bins or args.prune_number) and not \ (args.prune_param and args.prune_bins and args.prune_number): raise RuntimeError("To prune, need to specify param, number of bins and " "nonzero number to prune in each bin!") if args.output_file is not None and len(args.stat_threshold) > 1: raise RuntimeError("Cannot plot more than one threshold in a single " "output file!") if args.verbose: log_level = logging.DEBUG else: log_level = logging.WARN logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level) statname = "reweighted SNR" if args.sngl_stat == "new_snr" else \ args.sngl_stat.replace("_", " ").replace("snr", "SNR") paramname = args.bin_param.replace("_", " ") paramtag = args.bin_param.replace("_", "") if args.plot_dir: if not args.plot_dir.endswith('/'): args.plot_dir += '/' plotbase = outdir + args.ifo + "-" + args.user_tag logging.info('Opening trigger file: %s' % args.trigger_file) trigf = h5py.File(args.trigger_file, 'r') logging.info('Opening template file: %s' % args.bank_file) templatef = h5py.File(args.bank_file, 'r') # get the stat values minth = min(args.stat_threshold) snr = trigf[args.ifo+'/snr'][:] # SNR cut to reduce memory usage logging.info('Applying SNR cut at %f' % minth) aboveminth = snr >= minth snr = snr[aboveminth] chisq = trigf[args.ifo+'/chisq'][:][aboveminth] chisq_dof = 
rchisq = chisq / (2 * chisq_dof - 2)
try:
    sgchisq = trigf[args.ifo+'/sg_chisq'][:][aboveminth]
except KeyError:
    sgchisq = None
del chisq
del chisq_dof

logging.info('Calculating stat values')
stat = get_stat(args.sngl_stat, snr, rchisq, sgchisq, args.stat_factor)
del snr
del rchisq
del sgchisq

# get the duration values if needed
if args.bin_param == 'template_duration' and not args.f_lower:
    logging.info('Using template duration from the trigger file')
    trig_dur = True
else:
    trig_dur = False

# stat threshold to reduce trigger numbers
minth = min(args.stat_threshold)
abovethresh = stat >= minth
stat = stat[abovethresh]
tid = trigf[args.ifo+'/template_id'][:][aboveminth][abovethresh]
time = trigf[args.ifo+'/end_time'][:][aboveminth][abovethresh]
if trig_dur:
    tdur = trigf[args.ifo+'/template_duration'][:][aboveminth][abovethresh]
logging.info('%i trigs left after thresholding at %f' % (len(stat), minth))

# now do vetoing
for veto_file, veto_segment_name in zip(args.veto_file,
                                        args.veto_segment_name):
    retain, junk = events.veto.indices_outside_segments(
        time, [veto_file], ifo=args.ifo, segment_name=veto_segment_name)
    stat = stat[retain]
    tid = tid[retain]
    time = time[retain]
    if trig_dur:
        tdur = tdur[retain]
    logging.info('%i trigs left after vetoing with %s' %
                 (len(stat), veto_file))

### Functions for doing the pruning (removal of trigs at loudest times)

def get_masses(args, tid):
    m1 = templatef['mass1'][:][tid]
    m2 = templatef['mass2'][:][tid]
    s1z = templatef['spin1z'][:][tid]
    s2z = templatef['spin2z'][:][tid]
    return m1, m2, s1z, s2z

def get_param(args, tag, m1, m2, s1z, s2z):
    # here used for both pruning and binning params
    paramarg = getattr(args, tag + '_param')
    try:
        # will fail if m1 is a float rather than a sequence
        logging.info('Getting %s values for %i triggers' %
                     (paramarg, len(m1)))
    except TypeError:
        pass
    if paramarg == 'mchirp':
        parvals, _ = pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
    elif paramarg == 'mtotal':
        parvals = m1 + m2
    elif paramarg == 'template_duration':
        # will default to SEOBNRv4 duration function
        parvals = pnutils.get_imr_duration(m1, m2, s1z, s2z, args.f_lower)
        parvals += args.min_duration
    elif paramarg in pnutils.named_frequency_cutoffs.keys():
        parvals = pnutils.frequency_cutoff_from_name(paramarg, m1, m2,
                                                     s1z, s2z)
    else:
        # try asking for a LALSimulation frequency function
        parvals = pnutils.get_freq(paramarg, m1, m2, s1z, s2z)
    return parvals

# get min and max param values in bank
if args.prune_param:
    logging.info('Getting min and max param values')
    prpars = get_param(args, 'prune', templatef['mass1'][:],
                       templatef['mass2'][:], templatef['spin1z'][:],
                       templatef['spin2z'][:])
    minprpar = min(prpars)
    maxprpar = max(prpars)
    del prpars
    logging.info('prune param range %f %f' % (minprpar, maxprpar))

def which_bin(par, minpar, maxpar, nbins, log=False):
    """
    Returns the bin index in which the parameter value belongs (from 0
    through nbins-1) when dividing the range between minpar and maxpar
    equally into nbins bins.
    """
""" assert (par >= minpar and par <= maxpar) if log: par, minpar, maxpar = np.log(par), np.log(minpar), np.log(maxpar) # par lies some fraction of the way between min and max frac = float(par - minpar) / float(maxpar - minpar) # binind then lies between 0 and nbins - 1 binind = int(frac * nbins) # except for a corner case if par == maxpar: binind = nbins - 1 return binind if args.prune_param: # hard-coded time window of 0.1s args.prune_window = 0.1 # initialize bin storage prunedtimes = {} for i in range(args.prune_bins): prunedtimes[i] = [] # keep a record of the triggers if all successive loudest events were to # be pruned statpruneall = copy.deepcopy(stat) tidpruneall = copy.deepcopy(tid) timepruneall = copy.deepcopy(time) # many trials may be required to prune in 'quieter' bins for j in range(1000): # are all the bins full already? numpruned = sum([len(prunedtimes[i]) for i in range(args.prune_bins)]) if numpruned == args.prune_bins * args.prune_number: logging.info('Finished pruning!') break if numpruned > args.prune_bins * args.prune_number: logging.error('Uh-oh, we pruned too many things .. %i, to be ' 'precise' % numpruned) raise RuntimeError loudest = np.argmax(statpruneall) lstat = statpruneall[loudest] ltid = tidpruneall[loudest] ltime = timepruneall[loudest] m1, m2, s1z, s2z = get_masses(args, ltid) lbin = which_bin(get_param(args, 'prune', m1, m2, s1z, s2z), minprpar, maxprpar, args.prune_bins, log=args.log_prune_param) # is the bin where the loudest trigger lives full already? if len(prunedtimes[lbin]) == args.prune_number: logging.info('%i - Bin %i full, not pruning event with stat %f at time ' '%.3f' % (j, lbin, lstat, ltime)) # prune the reference trigger array retain = abs(timepruneall - ltime) > args.prune_window statpruneall = statpruneall[retain] tidpruneall = tidpruneall[retain] timepruneall = timepruneall[retain] del retain continue else: logging.info('Pruning event with stat %f at time %.3f in bin %i' % (lstat, ltime, lbin)) # now do the pruning retain = abs(time - ltime) > args.prune_window logging.info('%i trigs before pruning' % len(stat)) stat = stat[retain] logging.info('%i trigs remain' % len(stat)) tid = tid[retain] time = time[retain] if trig_dur: tdur = tdur[retain] # also for the reference trig arrays retain = abs(timepruneall - ltime) > args.prune_window statpruneall = statpruneall[retain] tidpruneall = tidpruneall[retain] timepruneall = timepruneall[retain] # record the time prunedtimes[lbin].append(ltime) del retain del statpruneall del tidpruneall del timepruneall # get binning params after tuning if trig_dur: binpars = tdur + args.min_duration else: m1, m2, s1z, s2z = get_masses(args, tid) binpars = get_param(args, 'bin', m1, m2, s1z, s2z) logging.info("Parameter range of triggers: %f - %f" % (min(binpars), max(binpars))) # get the bins # we assume that parvals are all positive assert min(binpars) >= 0 pmin = 0.999 * min(binpars) pmax = 1.001 * max(binpars) if args.bin_spacing == "linear": pbins = bin_utils.LinearBins(pmin, pmax, args.num_bins) elif args.bin_spacing == "log": pbins = bin_utils.LogarithmicBins(pmin, pmax, args.num_bins) # list of bin indices binind = [pbins[c] for c in pbins.centres()] logging.info("Assigning trigger param values to bins") # FIXME: This is slow!! 
pind = np.array([pbins[par] for par in binpars])

logging.info("Getting max counts in bins")
# determine trigger counts first to get plot limits to make them the same
# for all thresholds; use only the smallest threshold requested
bincounts = []
for i in binind:
    vals_inbin = stat[pind == i]
    bincounts.append(sum(vals_inbin >= minth))
maxcount = max(bincounts)

plotrange = np.linspace(0.95 * min(stat), 1.05 * max(stat), 100)

# initialize result storage
parbins = {}
counts = {}
templates = {}
alphas = {}
stdev = {}
ks_prob = {}
nabove = {}

histcolors = ['r', (1.0, 0.6, 0), 'y', 'g', 'c', 'b', 'm', 'k',
              (0.8, 0.25, 0), (0.25, 0.8, 0)]

for th in args.stat_threshold:
    logging.info("Fitting above threshold %f" % th)

    counts[th] = {}
    alphas[th] = {}
    stdev[th] = {}
    ks_prob[th] = {}

    if args.output_file:
        fig = plt.figure()

    for i, lower, upper in zip(binind, pbins.lower(), pbins.upper()):
        # determine number of templates generating the triggers involved
        # for hdf5, use the template id; otherwise use masses
        tid_inbin = tid[pind == i]
        numtmpl = len(set(tid_inbin))
        templates[i] = numtmpl
        vals_inbin = stat[pind == i]
        counts[th][i] = sum(vals_inbin >= th)
        if len(vals_inbin) == 0:
            logging.info("No trigs in bin %f-%f" % (lower, upper))
            continue

        # do the fit
        alpha, sig_alpha = trstats.fit_above_thresh(
            args.fit_function, vals_inbin, th)
        alphas[th][i] = alpha
        stdev[th][i] = sig_alpha
        _, ks_prob[th][i] = trstats.KS_test(
            args.fit_function, vals_inbin, alpha, th)

        # add histogram to plot
        histcounts, edges = np.histogram(vals_inbin, bins=50)
        cum_counts = histcounts[::-1].cumsum()[::-1]
        binlabel = r"%.3g - %.3g" % (lower, upper)
        # histogram of fitted values
        plt.semilogy(edges[:-1], cum_counts, linewidth=2,
                     color=histcolors[i], label=binlabel, alpha=0.6)
        # fit central value
        plt.semilogy(plotrange, counts[th][i] *
                     trstats.cum_fit(args.fit_function, plotrange, alpha, th),
                     "--", color=histcolors[i],
                     label=r"$\alpha = $%.2f $\pm$ %.2f" % (alpha, sig_alpha))
        # 1sigma upper deviation on alpha
        plt.semilogy(plotrange, counts[th][i] *
                     trstats.cum_fit(args.fit_function, plotrange,
                                     alpha + sig_alpha, th),
                     ":", alpha=0.6, color=histcolors[i])
        # 1sigma lower deviation
        plt.semilogy(plotrange, counts[th][i] *
                     trstats.cum_fit(args.fit_function, plotrange,
                                     alpha - sig_alpha, th),
                     ":", alpha=0.6, color=histcolors[i])

    # finish the hist plot
    leg = plt.legend(labelspacing=0.2)
    unitstring = " (%s)" % args.bin_param_units if \
                 args.bin_param_units is not None else ""
    leg.set_title(paramname + unitstring)
    plt.setp(leg.get_texts(), fontsize=11)
    plt.ylim(0.7, 2 * maxcount)
    plt.xlim(0.9 * minth, 1.1 * max(plotrange))
    plt.grid()
    plt.xlabel(statname, size="large")
    plt.ylabel("cumulative number", size="large")

    if args.plot_dir:
        plt.title(args.ifo + " " + statname + " distribution split by " +
                  paramname)
        dest = plotbase + "_" + args.sngl_stat + "_cdf_by_" + \
               paramtag[0:3] + "_fit_thresh_" + str(th) + ".png"
        logging.info("Saving cumhist to %s" % dest)
        plt.savefig(dest)
    elif args.output_file:
        logging.info("Saving cumhist to %s" % args.output_file)
        results.save_fig_with_metadata(
            fig, args.output_file,
            title="%s: %s histogram of single detector triggers" %
                  (args.ifo, statname),
            caption=(r"Histogram of single detector %s values binned by %s "
                     "with fitted %s distribution parameterized by α" %
                     (statname, paramname, args.fit_function)),
            cmd=" ".join(sys.argv)
        )
    plt.close()

# don't make any more plots if only making the single rainbow hist
if args.output_file:
    exit()

# make plots of alpha, trig count and KS significance
for th in args.stat_threshold:
    plt.errorbar(pbins.centres(), [alphas[th][i] for i in binind],
                 yerr=[stdev[th][i] for i in binind], fmt="+-",
                 label=args.ifo + " fit above %.2f" % th)
if args.bin_spacing == "log":
    plt.semilogx()
plt.xlim(0.03, 150)
plt.ylim(0, 10)
plt.grid()
plt.legend(loc="best")
plt.xlabel(paramname, size="large")
plt.ylabel(r"fit parameter $\alpha$", size='large')
plt.savefig(plotbase + '_alpha_vs_' + paramtag[0:3] + '.png')
plt.close()

for th in args.stat_threshold:
    plt.errorbar(pbins.centres(),
                 [float(counts[th][i]) / templates[i] for i in binind],
                 yerr=[counts[th][i] ** 0.5 / templates[i] for i in binind],
                 fmt="+-", label=args.ifo + " trigs above %.2f" % th)
if args.bin_spacing == "log":
    plt.semilogx()
plt.xlim(0.03, 150)
plt.grid()
plt.legend(loc="best")
plt.xlabel(paramname, size="large")
plt.ylabel(r"Triggers above threshold per template", size='large')
plt.savefig(plotbase + '_nabove_vs_' + paramtag[0:3] + '.png')
plt.close()

for th in args.stat_threshold:
    plt.plot(pbins.centres(), [ks_prob[th][i] for i in binind], '+--',
             label=args.ifo + ' KS prob, thresh %.2f' % th)
if args.bin_spacing == 'log':
    plt.loglog()
else:
    plt.semilogy()
plt.xlim(0.03, 150)
plt.grid()
leg = plt.legend(loc='best', labelspacing=0.2)
plt.setp(leg.get_texts(), fontsize=11)
plt.xlabel(paramname, size='large')
plt.ylabel('KS test p-value')
plt.savefig(plotbase + '_KS_prob_vs_' + paramtag[0:3] + '.png')
plt.close()

logging.info('Done!')
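
# For reference, an example invocation might look like the following.
# This is only a sketch: the script name, file names, ifo and parameter
# choices are hypothetical placeholders, not taken from any particular
# analysis; only the option names come from the argument parser above.
#
#   python fit_sngl_trigs.py \
#       --trigger-file H1-TRIGGERS.hdf --bank-file BANK.hdf --ifo H1 \
#       --fit-function exponential --sngl-stat new_snr \
#       --stat-threshold 6.25 6.5 6.75 \
#       --bin-param template_duration --bin-spacing log --num-bins 8 \
#       --prune-param mtotal --prune-bins 2 --prune-number 2 \
#       --plot-dir ./fit_plots/ --user-tag example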