Revision 4f75faeded2cb284dedbc856a8b2ae56075ea158 authored by Collin Capano on 20 June 2020, 18:27:09 UTC, committed by GitHub on 20 June 2020, 18:27:09 UTC
* use different acl for every chain in epsie

* create base burn in class, move common functions to there; rename MCMCBurnInTests EnsembleMCMC, first stab at creating MCMC tests for independent chains

* more changes to burn in module

* simplify the attributes in the burn in classes

* add write method to burn in classes

* add write_data method to base_hdf

* remove write_burn_in method from mcmc io; use the write method in burn in module instead

* make use of new burn in functions in sampler/base_mcmc

* have emcee and emcee pt use ensemble burn in tests

* add compute_acf function to epsie

* start separating ensemble and mcmc io methods

* stop saving thin settings to file; just return on the fly

* make read/write samples stand alone functions, and update emcee

* rename write functions; update emcee

* move multi temper read/write functions to stand alone and update emcee_pt

* pass kwargs from emcee(_pt) io functions

* simplify get_slice method

* add function to base_mcmc to calculate the number of samples in a chain

* use nsamples_in_chain function to calculate effective number of samples

* add read_raw_samples function that can handle differing number of samples from different chains

* add forgotten import

* use write/read functions from base_multitemper in epsie io

* use stand alone functions for computing ensemble acf/acls

* separate out ensemble-specific attributes in sampler module; update emcee and emcee_pt

* add acl and effective_nsample methods to epsie

* simplify writing acls and burn in

* fix various bugs and typos

* use a single function for writing both acl and raw_acls

* add some more logging info to burn in

* reduce identical blocks of code in burn in module

* fix self -> fp in read_raw_samples

* reduce code duplication in base io and simplify read raw samples function

* fix missed rename

* reduce code redundancy in sampler/base_multitemper

* whitespace

* fix bugs and typos in burn_in module

* fix code climate issues

* use map in compute_acl

* more code climate fixes

* remove unused variable; try to silence pylint

* fix issues reading epsie samples

* only load samples from burned in chains by default

* add act property to mcmc files

* fix act logging message

* fix effective number of samples calculation in epsie

* remap walkers option to chains for reading samples

* fix thinning update

* fix acceptance ratio and temperature data thinning in epsie

* allow for different fields to have differing number of temperatures when loading

* don't try to figure out how many samples will be loaded ahead of time

* store acts in file instead of acls

* write burn in status to file before computing acls

* drop write_acts function

* fix issue with getting specific chains

* fix typo

* code climate issues

* fix plot_acl
1 parent 926b628
# Copyright 2016 Thomas Dent
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import sys, h5py
import argparse, logging

from matplotlib import use
from matplotlib import pyplot as plt

import copy, numpy as np
from scipy import stats as scistats

from pycbc import io, events, bin_utils, results
from pycbc.events import triggers
from pycbc.events.stat import sngl_statistic_dict
import pycbc.version


def get_stat(statchoice, trigs):
    """Compute the single-detector statistic `statchoice` for `trigs`.

    The statistic class is looked up by name in ``sngl_statistic_dict`` and
    constructed with an empty file list; statistics that need auxiliary files
    would have to be given them at this point instead.

    Returns the result of calling the instance's ``single`` method on `trigs`.
    """
    stat_class = sngl_statistic_dict[statchoice]
    no_files = []
    return stat_class(no_files).single(trigs)

#### MAIN ####

# Command-line interface. Options whose values are used unconditionally
# further down (trigger file, bank file, output file) are marked required so
# that a missing value is reported by argparse rather than by a later crash.
parser = argparse.ArgumentParser(usage="", description="Plot trigger rates")

parser.add_argument("--version", action=pycbc.version.Version)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                    "Required")
parser.add_argument("--bank-file", required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument("--veto-file", nargs='*', default=[], action='append',
                    help="File(s) in .xml format with veto segments to apply "
                    "to triggers before fitting")
parser.add_argument("--veto-segment-name", nargs='*', default=[],
                    action='append',
                    help="Name(s) of veto segments to apply. Optional, if not "
                    "given all segments for a given ifo will be used")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--gps-start-time", type=float,
                    help="Time from which to load and plot triggers")
parser.add_argument("--gps-end-time", type=float,
                    help="End time up to which to load/plot triggers")
parser.add_argument("--sngl-stat", default="new_snr",
                    help="Function of SNR and chisq to threshold on")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics. Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
parser.add_argument("--stat-threshold", type=float,
                    help="Only plot triggers with statistic value above this "
                    "threshold")
parser.add_argument("--thorne-limit", action="store_true",
                    help="Remove triggers from templates with one or both "
                    "spins above 0.998")
parser.add_argument("--f-lower", type=float, default=0.,
                    help="Starting frequency for calculating template "
                    "duration; if not given, duration will be read from "
                    "single trigger files")
# NOTE(review): the final words of this help text were reconstructed --
# confirm against the upstream script.
parser.add_argument("--bin-param", required=True,
                    help="Parameter over which to bin. Required. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--bin-spacing", choices=["linear", "log", "irreg"],
                    help="How to space parameter bin edges")
# Regular and irregular binning are alternatives: at most one may be given.
binspec = parser.add_mutually_exclusive_group()
binspec.add_argument("--num-bins", type=int,
                     help="Number of regularly spaced bins to use over the "
                     " parameter")
binspec.add_argument("--irregular-bins", type=float, nargs="*",
                     help="Boundaries of irregular bins")
parser.add_argument("--bin-param-units",
                    help="String to display units of the binning parameter")
parser.add_argument("--approximant", default="SEOBNRv4",
                    help="Approximant for template duration. Default SEOBNRv4")
# type=float so that a value given on the command line can be added to the
# numeric template durations, like the float default.
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                    "values of template_duration: add to duration values "
                    "before fitting. Units seconds")
parser.add_argument("--kde-bandwidth", type=float, default=1.,
                    help="Width of the smoothing kernel as compared to total "
                    "plot duration. 0.02 usually works OK")
parser.add_argument("--raw-rate", action="store_true",
                    help="Plot rates without normalizing to template count")
parser.add_argument("--log-y", action="store_true")
parser.add_argument("--output-file", required=True,
                    help="Name of file to save plot")

args = parser.parse_args()

# Both veto options use nargs='*' with action='append', so each holds a list
# of lists (one inner list per occurrence of the option); flatten them.
args.veto_segment_name = sum(args.veto_segment_name, [])
args.veto_file = sum(args.veto_file, [])

# Veto files and segment names are paired one-to-one when vetoing.
if len(args.veto_segment_name) != len(args.veto_file):
    raise RuntimeError("Number of veto files must match number of veto "
                       "segment names")

# The GPS window is only usable if both ends are given.
have_start = bool(args.gps_start_time)
have_end = bool(args.gps_end_time)
if have_start != have_end:
    raise RuntimeError("I need both gps start time and end time!")

if args.verbose:
    log_level = logging.DEBUG
else:
    log_level = logging.WARNING
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Human-readable forms of the statistic and binning-parameter names
if args.sngl_stat == "new_snr":
    statname = "reweighted SNR"
else:
    statname = args.sngl_stat.replace("_", " ").replace("snr", "SNR")
paramname = args.bin_param.replace("_", " ")
paramtag = args.bin_param.replace("_", "")

logging.info('Opening trigger file: %s', args.trigger_file)
trigf = h5py.File(args.trigger_file, 'r')
logging.info('Opening template file: %s', args.bank_file)
templatef = h5py.File(args.bank_file, 'r')

# get the stat values
stat = get_stat(args.sngl_stat, trigf[args.ifo])

# Durations are read from the trigger file only when binning by
# template_duration with no --f-lower from which to recompute them.
trig_dur = args.bin_param == 'template_duration' and not args.f_lower
if trig_dur:
    logging.info('Using template duration from the trigger file')

# stat threshold to reduce trigger numbers
abovethresh = stat >= args.stat_threshold
stat = stat[abovethresh]
tid = trigf[args.ifo+'/template_id'][:][abovethresh]
time = trigf[args.ifo+'/end_time'][:][abovethresh]
if trig_dur:
    tdur = trigf[args.ifo+'/template_duration'][:][abovethresh]
logging.info('%i trigs left after thresholding at %f', len(stat),
             args.stat_threshold)
# only the surviving template ids / times (/ durations) are needed from here
del stat

# restrict to the requested GPS window [start, end)
if args.gps_start_time and args.gps_end_time:
    inside = np.logical_and(time >= args.gps_start_time,
                            time < args.gps_end_time)
    tid = tid[inside]
    time = time[inside]
    if trig_dur:
        tdur = tdur[inside]
    logging.info('%i trigs left after restricting gps times', len(time))

# now do vetoing, one (file, segment name) pair at a time
for vfile, vsegmentname in zip(args.veto_file, args.veto_segment_name):
    retain, junk = events.veto.indices_outside_segments(
        time, [vfile], ifo=args.ifo, segment_name=vsegmentname)
    tid = tid[retain]
    time = time[retain]
    if trig_dur:
        tdur = tdur[retain]
    logging.info('%i trigs left after vetoing with %s', len(time), vfile)

# get a minimum time for plotting purposes
if not args.gps_start_time:
    args.gps_start_time = int(time.min())

if args.thorne_limit:
    # drop triggers from templates with either spin magnitude at or above
    # 0.998
    m1, m2, s1z, s2z = triggers.get_mass_spin(templatef, tid)
    inside = np.logical_and(abs(s1z) < 0.998, abs(s2z) < 0.998)
    tid = tid[inside]
    time = time[inside]
    if trig_dur:
        tdur = tdur[inside]

# Computed only after every cut has been applied so that plottime stays the
# same length as tid (and tdur), which are indexed together later.
plottime = time - args.gps_start_time

def get_pars(args, tag, m1, m2, s1z, s2z):
    """Return values of the parameter named by the ``<tag>_param`` option.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed options. ``getattr(args, tag + '_param')`` names the parameter
        and the whole namespace is forwarded to ``triggers.get_param``.
    tag : str
        Option prefix, e.g. ``'bin'`` selects ``args.bin_param``.
    m1, m2, s1z, s2z : float or sequence
        Template quantities, passed through unchanged to
        ``triggers.get_param``.

    Returns
    -------
    The result of ``triggers.get_param`` for the selected parameter.
    """
    # used for binning params
    paramarg = getattr(args, tag+'_param')
    try:
        ntrig = len(m1)
    except TypeError:
        # len() will fail if m1 is a float rather than a sequence
        ntrig = 1
    logging.info('Getting %s values for %i triggers', paramarg, ntrig)
    return triggers.get_param(paramarg, args, m1, m2, s1z, s2z)

# get binning params: either the durations stored with the triggers (plus the
# --min-duration offset) or values derived from the template masses and spins
if trig_dur:
    binpars = tdur + args.min_duration
else:
    m1, m2, s1z, s2z = triggers.get_mass_spin(templatef, tid)
    binpars = get_pars(args, 'bin', m1, m2, s1z, s2z)
par_lo = np.min(binpars)
par_hi = np.max(binpars)
logging.info("Parameter range of triggers: %f - %f", par_lo, par_hi)

# get the bins
# we assume that parvals are all positive; the bin range is padded by
# scaling, which only widens the range for non-negative values
if par_lo < 0:
    raise ValueError("Binning parameter values must be non-negative, "
                     "got minimum %f" % par_lo)
pmin = 0.999 * par_lo
pmax = 1.001 * par_hi
# NOTE(review): only the first two colours were recoverable; the rest of the
# palette is a reconstruction -- confirm against the upstream script.
bincolors = ['r', (1.0, 0.65, 0), 'g', 'c', 'b', 'm', 'k']
if args.bin_spacing == "linear":
    pbins = bin_utils.LinearBins(pmin, pmax, args.num_bins)
elif args.bin_spacing == "log":
    pbins = bin_utils.LogarithmicBins(pmin, pmax, args.num_bins)
elif args.bin_spacing == "irreg":
    # allow bins in reverse order!
    if args.irregular_bins[1] < args.irregular_bins[0]:
        args.irregular_bins = args.irregular_bins[::-1]
    pbins = bin_utils.IrregularBins(args.irregular_bins)
else:
    # --bin-spacing has no default, so it can be None here
    raise ValueError("--bin-spacing must be one of linear, log or irreg")

# list of bin indices
binind = [pbins[c] for c in pbins.centres()]
logging.info("Assigning trigger param values to bins")
# This is slow!! Either find a better way of using pylal.rate or write faster
# binning routine
pind = np.array([pbins[par] for par in binpars])

# initialize result storage
bincounts = {}
bintemplates = {}
maxtime = int(plottime.max()) + 1

# running y-axis limits, updated as each bin's curve is computed
minrate = 10000
maxrate = 0.0001

fig = plt.figure()
# The evaluation grid is the same for every bin, so build it once.
xplot = np.linspace(0, maxtime, int(10 / args.kde_bandwidth))
for i, lower, upper in zip(binind, pbins.lower(), pbins.upper()):
    # determine number of templates generating the triggers involved
    # using the template id
    inbin = pind == i
    bintemplates[i] = len(set(tid[inbin]))
    times_inbin = plottime[inbin]
    bincounts[i] = len(times_inbin)
    if bincounts[i] < 2:
        # gaussian_kde needs more than one data point, so skip this bin
        logging.info("Too few trigs (%i) in bin %f-%f to estimate a rate",
                     bincounts[i], lower, upper)
        continue
    # add KDE to plot
    kd = scistats.gaussian_kde(times_inbin, bw_method=args.kde_bandwidth)
    # the KDE is a normalized PDF, so rescale by the trigger count to turn it
    # into a rate, and by the template count unless --raw-rate was given
    yplot = kd(xplot) * bincounts[i]
    if not args.raw_rate:
        yplot = yplot / bintemplates[i]
    minrate = min(minrate, yplot.max()/5e2)
    maxrate = max(maxrate, yplot.max())
    binlabel = r"%.3g - %.3g" % (lower, upper)
    # cycle through the palette so extra bins do not index past its end
    plt.plot(xplot, yplot, '-', c=bincolors[i % len(bincolors)],
             label=binlabel)

# finish the plot
leg = plt.legend(labelspacing=0.2, loc='lower center')
unitstring = " (%s)" % args.bin_param_units if \
                                       args.bin_param_units is not None else ""
plt.setp(leg.get_texts(), fontsize=11)
if args.log_y:
    plt.yscale('log')
plt.ylim(minrate, 1.4*maxrate)
plt.xlim(0, maxtime)
plt.xlabel("GPS time after %i" % args.gps_start_time, size="large")
# the y-axis label must match the normalization chosen in the loop above
if args.raw_rate:
    plt.ylabel(r"Trigger rate (s$^{-1}$)", size="large")
else:
    plt.ylabel(r"Rate per template (s$^{-1}$)", size="large")
logging.info("Saving to %s", args.output_file)
results.save_fig_with_metadata(
    fig, args.output_file,
    title="%s: rate of triggers above a %s threshold %f" % (args.ifo,
              statname, args.stat_threshold),
    caption=(r"Rate of %s single detector triggers thresholded on %s"
             % (args.ifo, statname)),
    cmd=" ".join(sys.argv)
)
