Revision 6fd711f62b16977a90c97b13f81fe90f24f38f1a authored by Sebastian Khan on 25 May 2017, 09:45:24 UTC, committed by GitHub on 25 May 2017, 09:45:24 UTC
* add coaphase to hwinj

allows you to specify the reference orbital
phase when using pycbc_generate_hwinj

This parameter is called phiRef in LALSimulation
and coaphase in sim_inspiral tables.

* fix naming bug

* set default value to zero

* update help message

* fix build and implement Ian's comment
1 parent b48e9b1
pycbc_live
#!/usr/bin/env python
import argparse, numpy, pycbc, logging, cProfile, h5py, lal
from pycbc import fft, version, waveform, scheme, frame
from pycbc.types import MultiDetOptionAction
from pycbc.filter import LiveBatchMatchedFilter
from pycbc.strain import StrainBuffer
from time import time
import os.path
from mpi4py import MPI as mpi
from pycbc.events import newsnr
from pycbc.events.coinc import LiveCoincTimeslideBackgroundEstimator
from pycbc.events.single import LiveSingleFarThreshold
from pycbc.io.live import SingleCoincForGraceDB
import pycbc.waveform.bank

def makedir(path):
    """
    Make the analysis directory path and any parent directories that don't
    already exist. Will do nothing if path already exists.
    """
    if path is not None and not os.path.exists(path):
        os.makedirs(path)

class LiveEventManager(object):
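    """ Coordinates the analysis across MPI processes: gathers triggers from
    the worker ranks, writes output files, and (optionally) uploads candidate
    events to GraceDB from the root process.
    """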
    def __init__(self, output_path,
                       use_date_prefix=False,
                       ifar_upload_threshold=None,
                       enable_gracedb_upload=False,
                       gracedb_testing=True,
                 ):
        self.path = output_path
        
        # Figure out what we are supposed to process within the pool of MPI processes
        self.comm = mpi.COMM_WORLD
        self.size = self.comm.Get_size()
        self.rank = self.comm.Get_rank()

        self.use_date_prefix = use_date_prefix
        self.ifar_upload_threshold = ifar_upload_threshold
        self.gracedb_testing = gracedb_testing
        self.enable_gracedb_upload = enable_gracedb_upload

    def commit_results(self, results):
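        """ Send this process' filtering results to the root process. """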
        self.comm.gather(results, root=0)

    def barrier(self):
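        """ Wait until all MPI processes have reached this point. """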
        self.comm.Barrier()

    def barrier_status(self, status):
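        """ Return True only if every MPI process reports a True status
        (logical AND across all ranks). """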
        return self.comm.allreduce(status, op=mpi.LAND)

    def gather_results(self):
        """ Collect results from the mpi subprocesses and collate them into
        contiguous sets of arrays.
        """

        if self.rank == 0:
            all_results = self.comm.gather(None, root=0)
            data_ends = [a[1] for a in all_results if a is not None]
            results = [a[0] for a in all_results if a is not None]    
        
            combined = {}
            for ifo in results[0]:

                # skip this detector if any process returned an invalid result
                if any(r[ifo] is False for r in results):
                    continue
                combined[ifo] = {}
                for key in results[0][ifo]:
                    combined[ifo][key] = numpy.concatenate([r[ifo][key] for r in results])

            return combined, data_ends[0]
        else:
            raise TypeError("Not root process")

    def check_coincs(self, ifos, coinc_results, psds, low_frequency_cutoff,
                     data_readers, bank, followup_ifos=None):
        """ Save zerolag triggers to a coinc xml file """
        if 'foreground/ifar' in coinc_results: 
            ifar = coinc_results['foreground/ifar']
            event = SingleCoincForGraceDB(
                    ifos, coinc_results, data_readers=data_readers,
                    bank=bank, upload_snr_series=args.upload_snr_series,
                    followup_ifos=followup_ifos)

            end_time = int(coinc_results['foreground/%s/end_time' % ifos[0]])
            fname = os.path.join(self.path, 'coinc-%s.xml' % end_time)

            comments = ['using ranking statistic: %s' % args.background_statistic]

            if self.enable_gracedb_upload and self.ifar_upload_threshold < ifar:
                event.upload(fname, psds, low_frequency_cutoff, 
                             testing=self.gracedb_testing, extra_strings=comments)
            else:
                event.save(fname)

            logging.info('saving coinc to xml %s', fname)

    def check_singles(self, results, data_reader, psds, low_frequency_cutoff):
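        """ Check for single-detector candidates when only one detector has
        usable data, and save (or optionally upload) any that are found. """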
        active = [k for k in results if results[k] is not None]
        if len(active) == 1:
            ifo = active[0]
            single = sngl_estimator[ifo].check(results[ifo], data_reader[ifo])
            if single is not None:
                fname = 'single-%s-%s.xml' % (ifo, int(single.time))
                fname = os.path.join(self.path, fname)
                if args.enable_single_detector_upload:
                    single.upload(fname, {ifo:psds[ifo]},
                              low_frequency_cutoff,
                              testing=self.gracedb_testing) 
                else:
                    single.save(fname)                   

    def dump(self, results, name, store_psd=False, time_index=None,
                  store_loudest_index=False,
                  raw_results=None,
                  gracedb_results=None):
        """ Save the results from this time block to an hdf output file """
        if self.use_date_prefix:
            tm = lal.GPSToUTC(int(time_index))
            subdir = '%s_%s_%s' % (tm[0], tm[1], tm[2])
            makedir(os.path.join(self.path, subdir))
            fname = os.path.join(self.path, subdir, name) + '.hdf'
        else:
            makedir(self.path)
            fname = os.path.join(self.path, name) + '.hdf'

        f = h5py.File(fname, 'w')
        for ifo in results:
            for k in results[ifo]:
                f['%s/%s' % (ifo, k)] = results[ifo][k]

        for key in raw_results:
            f[key] = raw_results[key]

        if store_loudest_index:
            for ifo in results:
                if 'snr' in results[ifo]:
                    s = numpy.array(results[ifo]['snr'], ndmin=1)
                    c = numpy.array(results[ifo]['chisq'], ndmin=1)
                    nsnr = numpy.array(newsnr(s, c), ndmin=1) if len(s) > 0 else []
                    
                    # loudest by newsnr
                    nloudest = numpy.argsort(nsnr)[::-1][0:store_loudest_index]

                    # loudest by snr
                    sloudest = numpy.argsort(s)[::-1][0:store_loudest_index]
                    f[ifo]['loudest'] = numpy.union1d(nloudest, sloudest)

        f.close()
        if store_psd:
            for ifo in store_psd:
                if store_psd[ifo] is not None:
                    store_psd[ifo].save(fname, group='%s/psd' % ifo)

    def ingest(self, ifo, results):
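        """ Placeholder hook for ingesting single-detector results;
        currently a no-op. """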
        pass

parser = argparse.ArgumentParser(description="Look for CBC signals in low "
    "latency. This program takes a fixed template bank and analyzes in fixed "
    "short blocks of time. It is capable of reading from low latency data "
    "directories on the ldg clusters. Work is parallelized through the use "
    "of MPI, so this script should typically be launched with mpirun")
pycbc.waveform.bank.add_approximant_arg(parser)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--version', action='version', version=version.git_verbose_msg)
parser.add_argument('--bank-file', help="Template bank xml file")
parser.add_argument('--low-frequency-cutoff', help="low frequency cutoff", type=int)
parser.add_argument('--sample-rate', help="output sample rate", type=int)
parser.add_argument('--chisq-bins', help="Number of chisq bins")
parser.add_argument('--analysis-chunk', type=int,
                        help="Amount of data to produce triggers in a  block")

parser.add_argument('--snr-threshold', type=float)
parser.add_argument('--snr-abort-threshold', type=float)

parser.add_argument('--channel-name', action=MultiDetOptionAction, nargs='+')
parser.add_argument('--state-channel', action=MultiDetOptionAction, nargs='+', 
          help="The channel containing frame status information. This is used "
               "to determine when to analyze the hoft data. This somewhat "
               "corresponds to CAT1 information")
parser.add_argument('--analyze-flags', nargs='+', default=['HOFT_OK', 'SCIENCE_INTENT'], 
                    help='The flags that must be in the "good" state to analyze data')
parser.add_argument('--data-quality-channel', action=MultiDetOptionAction, nargs='+', 
          help="The channel containing data quality information. This is used "
                "to determine when hoft may be suspect and may be used to veto"
                "triggers or not analyze a segment of data. This roughly "
                "corresponds to CAT2 information")
parser.add_argument('--data-quality-flags', nargs='+', 
          help='The flags used to determine when to throw triggers away')
parser.add_argument('--data-quality-padding', type=float, default=0,
          help='Additional time in seconds around a bad data quality time within which to remove triggers')
parser.add_argument('--frame-src', action=MultiDetOptionAction, nargs='+')
parser.add_argument('--frame-type', action=MultiDetOptionAction, nargs='+')
parser.add_argument('--force-update-cache', action='store_true')
parser.add_argument('--highpass-frequency', type=float,
                        help="Frequency to apply highpass filtering")
parser.add_argument('--highpass-reduction', type=float,
                        help="DB to reduce low frequencies")
parser.add_argument('--highpass-bandwidth', type=float,
                        help="Width of the highpass turnover region in Hz")
parser.add_argument('--psd-recalculate-difference', type=float, default=.01)
parser.add_argument('--psd-abort-difference', type=float, default=.20)
parser.add_argument('--psd-samples', type=int, 
                        help="Number of PSD segments to use in the rolling estimate")
parser.add_argument('--psd-segment-length', type=int, 
                        help="Length in seconds of each PSD segment")
parser.add_argument('--psd-inverse-length', type=float, 
                        help="Lenght in time for the equivelant FIR filter")
parser.add_argument('--trim-padding', type=float, default=0.25,
                        help="Padding around the overwhitened analysis block")
parser.add_argument("--enable-bank-start-frequency", action='store_true',
                  help="Read the starting frequency of template waveforms"
                       " from the template bank.")
parser.add_argument('--autogating-threshold', type=float, default=100)
parser.add_argument('--autogating-pad', type=float, default=.25)
parser.add_argument('--autogating-window', type=float, default=0.5)
parser.add_argument('--autogating-cluster', type=float, default=.25)

parser.add_argument('--sync', action='store_true')
parser.add_argument('--increment-update-cache', action=MultiDetOptionAction, nargs='+')
parser.add_argument('--frame-read-timeout', type=float, default=30)
parser.add_argument('--increment', type=int, default=8)

parser.add_argument('--start-time', type=int, default=lal.GPSTimeNow())
parser.add_argument('--end-time', type=int, default=numpy.inf)

parser.add_argument('--output-path')
parser.add_argument('--day-hour-output-prefix', action='store_true')
parser.add_argument('--store-psd', action='store_true')

parser.add_argument('--newsnr-threshold', type=float, default=0)
parser.add_argument('--max-batch-size', type=int, default=2**27)
parser.add_argument('--store-loudest-index', type=int, default=0)
parser.add_argument('--max-psd-abort-distance', type=float, default=numpy.inf)
parser.add_argument('--min-psd-abort-distance', type=float, default=-numpy.inf)
parser.add_argument('--max-triggers-in-batch', type=int)
parser.add_argument('--max-length', type=float,
                    help='Maximum duration of templates, used to set the data buffer size')

parser.add_argument('--enable-profiling', type=int, 
                    help="Dump out profiling information from an MPI process at"
                         " the end of program executrion")

parser.add_argument('--enable-background-estimation', default=False, action='store_true')
parser.add_argument('--ifar-upload-threshold', type=float)
parser.add_argument('--enable-gracedb-upload', action='store_true', default=False)
parser.add_argument('--file-prefix', default='Live')
parser.add_argument('--enable-production-gracedb-upload', action='store_true', default=False)
parser.add_argument('--enable-single-detector-upload', action='store_true', default=False)
parser.add_argument('--upload-snr-series', action='store_true', default=False)
parser.add_argument('--followup-detectors', default=[], nargs='*',
                    help='List of detectors to use as followup')
parser.add_argument('--enable-single-detector-background', action='store_true', default=False)
parser.add_argument('--round-start-time', type=int,
                    help="Round up the start time to the nearest multiple of X"
                         " seconds. This is useful for forcing agreement "
                         " with frame file production.")

scheme.insert_processing_option_group(parser)
LiveSingleFarThreshold.insert_args(parser)
fft.insert_fft_option_group(parser)
LiveCoincTimeslideBackgroundEstimator.insert_args(parser)
args = parser.parse_args()
scheme.verify_processing_options(args, parser)

import platform
pycbc.init_logging(args.verbose, format="%(asctime)s: %(message)s" + ": %s: " % platform.node())

ctx = scheme.from_cli(args)

sr = args.sample_rate
flow = args.low_frequency_cutoff

if args.round_start_time:
    args.start_time = int(args.start_time / args.round_start_time + 1) * args.round_start_time 
    logging.info('Starting from: %s', args.start_time)

# Approximate guess of the total padding
valid_pad = args.analysis_chunk
total_pad = args.trim_padding * 2 + valid_pad
bank = waveform.LiveFilterBank(args.bank_file, sr, total_pad,
                       low_frequency_cutoff=None if args.enable_bank_start_frequency else flow,
                       approximant=args.approximant,
                       increment=args.increment)

evnt = LiveEventManager(args.output_path, 
                        use_date_prefix=args.day_hour_output_prefix, 
                        ifar_upload_threshold=args.ifar_upload_threshold, 
                        enable_gracedb_upload=args.enable_gracedb_upload,
                        gracedb_testing=not args.enable_production_gracedb_upload,
                       )

# ifos used for primary analysis (only two at the moment)
ifos = set(args.channel_name.keys())
followup_ifos = set(args.followup_detectors)
ifos -= followup_ifos
ifos, followup_ifos = list(ifos), list(followup_ifos)

# the root must also advance followup ifos, other threads don't need to
if evnt.rank == 0:
    advancing_ifos = ifos + followup_ifos
else:
    advancing_ifos = ifos

# Enter the processing scheme context; worker ranks do the actual filtering,
# while rank 0 collects, records, and follows up the results.
with ctx:
    try:
        fft.from_cli(args)
    except:
        pass
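    # Default buffer length, sized to hold the data needed for the rolling PSD
    # estimate (PSD segments are assumed to overlap by 50%).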
    maxlen = args.psd_segment_length * (args.psd_samples / 2 + 1)
    if evnt.rank > 0:
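        # Worker ranks (rank > 0) each filter a round-robin slice of the
        # template bank; sorting by chirp mass first groups templates of
        # similar duration, which helps balance the workload across ranks.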
        bank.table.sort(order='mchirp')
        waveforms = list(bank[evnt.rank-1::evnt.size-1])
        lengths = numpy.array([1.0 / wf.delta_f for wf in waveforms])
        psd_len = args.psd_segment_length * (args.psd_samples / 2 + 1)
        maxlen = max(lengths.max(), psd_len)
        mf = LiveBatchMatchedFilter(waveforms, args.snr_threshold, args.chisq_bins,
                                     snr_abort_threshold=args.snr_abort_threshold, 
                                     newsnr_threshold=args.newsnr_threshold,
                                     max_triggers_in_batch=args.max_triggers_in_batch,
                                     maxelements=args.max_batch_size)
    if args.max_length is not None:
        maxlen = args.max_length
    maxlen = int(maxlen)
    data_reader = {}
    sngl_estimator = {}
    for ifo in advancing_ifos:
        state = '%s:%s' % (ifo, args.state_channel[ifo]) if args.state_channel else None
        dq = '%s:%s' % (ifo, args.data_quality_channel[ifo]) if args.data_quality_channel else None

        if args.frame_type:
            frame_src = frame.frame_paths(args.frame_type[ifo], args.start_time, args.end_time)
        else:
            frame_src = [args.frame_src[ifo]]

        if args.enable_single_detector_background and evnt.rank == 0:
            sngl_estimator[ifo] = LiveSingleFarThreshold.from_cli(args, ifo)
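        # Each detector gets a StrainBuffer that handles frame reading, strain
        # conditioning, rolling PSD estimation, gating, and state/data-quality
        # bookkeeping.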

        data_reader[ifo] = StrainBuffer(frame_src,
                            '%s:%s' % (ifo, args.channel_name[ifo]),
                            args.start_time, max_buffer=maxlen * 2,
                            state_channel=state,
                            data_quality_channel=dq,
                            sample_rate=args.sample_rate,
                            low_frequency_cutoff=args.low_frequency_cutoff,
                            highpass_frequency=args.highpass_frequency,
                            highpass_reduction=args.highpass_reduction,
                            highpass_bandwidth=args.highpass_bandwidth,
                            psd_samples=args.psd_samples,
                            trim_padding=args.trim_padding,
                            psd_segment_length=args.psd_segment_length,
                            psd_inverse_length=args.psd_inverse_length,
                            autogating_threshold=args.autogating_threshold,
                            autogating_cluster=args.autogating_cluster,
                            autogating_window=args.autogating_window,
                            autogating_pad=args.autogating_pad,
                            psd_abort_difference=args.psd_abort_difference,
                            psd_recalculate_difference=args.psd_recalculate_difference,
                            force_update_cache=args.force_update_cache,
                            increment_update_cache=args.increment_update_cache[ifo],
                            analyze_flags=args.analyze_flags,
                            data_quality_flags=args.data_quality_flags)
    if args.enable_background_estimation and evnt.rank == 0:
        estimator = LiveCoincTimeslideBackgroundEstimator.from_cli(args,
                            len(bank), args.analysis_chunk, ifos)



    logging.info('%s: Starting... ', evnt.rank)

    if args.enable_profiling is not None and evnt.rank == args.enable_profiling:
        pr = cProfile.Profile()
        pr.enable()

    # convenience function giving the end time of the data currently held by the readers
    data_end = lambda: data_reader[data_reader.keys()[0]].end_time
    i = 0
    while data_end() < args.end_time:
        t1 = time()
        logging.info('%s: Analyzing up to %s', evnt.rank, data_end())
        results = {}

        for ifo in advancing_ifos:
            results[ifo] = False
            status = data_reader[ifo].advance(valid_pad, timeout=args.frame_read_timeout)
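            # advance() reports whether this block of data can be analyzed;
            # False typically means the data was missing or invalid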

            if status is True:
                status = data_reader[ifo].recalculate_psd()

            if data_reader[ifo].psd is not None:
                dist = data_reader[ifo].psd.dist
                if dist < args.min_psd_abort_distance or dist > args.max_psd_abort_distance:
                    logging.info("PSD outside acceptable sensitivity %s - %s, have %s",
                                 args.min_psd_abort_distance, args.max_psd_abort_distance, dist)
                    status = False               

            if ifo in followup_ifos:
                # we advanced the data and handled the PSD, that's it
                continue

            if status is True:
                logging.info('%s: Filtering: %s', evnt.rank, ifo)
                if evnt.rank > 0:
                    results[ifo] = mf.process_data(data_reader[ifo])
            else:
                logging.info('Insufficient data for analysis')

        if evnt.rank > 0:
            evnt.commit_results((results, data_end()))
        else:
            psds = {}
            for ifo in data_reader:
                psds[ifo] = data_reader[ifo].psd

            # Collect together the single detector triggers
            if evnt.size > 1:
                results, valid_end = evnt.gather_results()

            # Veto single detector triggers if they fail the DQ vector
            if args.data_quality_channel:
                logging.info('checking the dq vector')
                for ifo in results:
                    start = data_reader[ifo].start_time
                    end = data_reader[ifo].end_time
                    times = results[ifo]['end_time']
                    idx = data_reader[ifo].dq.indices_of_flag(start, valid_pad, times,
                                                              padding=args.data_quality_padding)
                    logging.info('Keeping %s triggers out of %s triggers in %s', len(idx), len(times), ifo)
                    for key in results[ifo]:
                        if len(results[ifo][key]) > 0:
                            results[ifo][key] = results[ifo][key][idx]

            # Look for coincident triggers and do background estimation
            coinc_results = {}
            if args.enable_background_estimation:
                coinc_results = estimator.add_singles(results, data_reader)
                evnt.check_coincs(results.keys(), coinc_results,
                                  psds, args.low_frequency_cutoff, data_reader, bank,
                                  followup_ifos=followup_ifos)

            # Check for single-detector candidates when only one detector has analyzable data
            if args.enable_single_detector_background:
                evnt.check_singles(results, data_reader, psds, args.low_frequency_cutoff)

            prefix = '%s-%s-%s-%s' % (''.join(ifos), args.file_prefix, data_end() - args.analysis_chunk, valid_pad)
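            # write this block's triggers to an hdf file named with the usual
            # IFOS-TAG-GPSSTART-DURATION convention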
            evnt.dump(results, prefix, time_index=data_end(),
                      store_psd=False if args.store_psd is False else psds,
                      store_loudest_index=args.store_loudest_index,
                      raw_results=coinc_results,
                     )     

        if args.sync: evnt.barrier()
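        # with --sync, the barrier above keeps all ranks working on the same
        # span of data before moving on to the next block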
        tdiff = time() - t1
        logging.info('%s: Took %1.2f, duty factor of %.2f', evnt.rank, tdiff, tdiff / valid_pad)
        i += 1

if args.fftw_output_float_wisdom_file:
    fft.fftw.export_single_wisdom_to_filename(args.fftw_output_float_wisdom_file)

if args.fftw_output_double_wisdom_file:
    fft.fftw.export_double_wisdom_to_filename(args.fftw_output_double_wisdom_file)

if args.enable_profiling is not None and args.enable_profiling == evnt.rank:
    pr.dump_stats('log')