https://github.com/gwastro/pycbc
Raw File
Tip revision: 66a120a4cb510b288f02076827a3a34eb1697ca0 authored by Gareth S Cabourn Davies on 24 August 2021, 14:54:17 UTC
prep for release (#3768)
Tip revision: 66a120a
pycbc_grb_trig_combiner
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2019 Duncan Macleod
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""Combine triggers from a splitbank GRB run
"""

from __future__ import (division, print_function)

import argparse
import os
from collections import defaultdict

import numpy

import tqdm

import h5py

from gwdatafind.utils import (file_segment, filename_metadata)

from ligo import segments
from ligo.segments.utils import fromsegwizard

from pycbc import __version__

__author__ = "Duncan Macleod <duncan.macleod@ligo.org>"

# layout of every progress bar drawn by this script
TQDM_BAR_FORMAT = (
    "{desc}: |{bar}| "
    "{n_fmt}/{total_fmt} {unit} ({percentage:3.0f}%) "
    "[{elapsed} | ETA {remaining}]"
)

# keyword arguments shared by all tqdm.tqdm(...) calls below
TQDM_KW = dict(
    ascii=" -=#",
    bar_format=TQDM_BAR_FORMAT,
    smoothing=0.05,
)


# -- utilities ----------------------------------

def _init_combined_file(h5file, attributes, datasets, **compression_kw):
    """Prepare the output file of a merge.

    Copies ``attributes`` into ``h5file.attrs`` and then creates one
    dataset per entry of ``datasets``, a mapping of dataset name to a
    ``(shape, dtype)`` pair. Every ``create_dataset`` call receives the
    same ``compression_kw``.
    """
    h5file.attrs.update(attributes)
    for name, spec in datasets.items():
        shape, dtype = spec
        h5file.create_dataset(
            name,
            shape=shape,
            dtype=dtype,
            **compression_kw
        )


def _merge_dataset(source, target, position, is_event_id=False, ifo_event_id_pos=None):
    """Merge one dataset into another at a given position.

    Parameters
    ----------
    source : `h5py.Dataset`
        the source dataset

    target : `h5py.Dataset`
        the target dataset

    position : `int`
        the index of ``target`` at which to insert ``source``

    is_event_id : `bool`, optional
        if `True` the values of ``source`` will be incremented by ``position``;
        this should be used to increment the ``*/event_id`` columns by the
        internal count of events

    ifo_event_id_pos : `int`, `None`, optional
        if not `None`, the values of ``source`` will be incremented by
        the given value (this takes precedence over ``is_event_id``);
        this should be used to increment the ``network/X1_event_id``
        column values by the current count of ``X1/event_id`` events

    Returns
    -------
    size : `int`
        the size of the dataset that was just merged
    """
    data = source[()]
    size = data.shape[0]
    if ifo_event_id_pos is not None:
        # increment network/X1_event_id by count of X1/event_id
        # (an offset of 0 is a valid, explicit value, not "unset")
        offset = ifo_event_id_pos
    elif is_event_id:
        # increment X1/event_id by number of events recorded so far
        offset = position
    else:
        offset = 0
    if offset:
        # out-of-place, so an in-memory array passed as ``source`` is
        # never modified as a side effect
        data = data + offset
    target[position:position + size] = data
    return size


def merge_hdf5_files(inputfiles, outputfile, verbose=False, **compression_kw):
    """Merge several HDF5 files into a single file

    Datasets with the same name in each input file are stacked along their
    first axis, in the order of ``inputfiles``. Datasets whose name
    contains ``/search/`` are copied only once, and the ``event_id``
    columns are re-numbered using the running event counts.

    Parameters
    ----------
    inputfiles : `list` of `str`
        the paths of the input HDF5 files to merge

    outputfile : `str`
        the path of the output HDF5 file to write

    verbose : `bool`, optional
        if `True`, show progress bars and print a summary

    **compression_kw
        keyword arguments passed when creating each output dataset;
        ``compression`` and ``compression_opts`` default to those of the
        first (non-search) input dataset scanned

    Returns
    -------
    outputfile : `str`
        the path of the file that was written

    Raises
    ------
    ValueError
        if a dataset cannot be merged because its dtype or trailing
        dimensions differ between input files
    """
    attributes = {}
    datasets = {}
    nfiles = len(inputfiles)

    def _scan_dataset(name, obj):
        """Accumulate the merged shape and dtype of one dataset

        Used as the ``visititems`` callback for each input file.
        """
        if not isinstance(obj, h5py.Dataset):
            return None
        if "/search/" in name:
            # search datasets are not stacked, just recorded as-is
            datasets[name] = (obj.shape, obj.dtype)
            return None

        shape = tuple(obj.shape)
        dtype = obj.dtype
        if name in datasets:
            oldshape, olddtype = datasets[name]
            if dtype != olddtype:
                raise ValueError(
                    "Cannot merge {0}/{1}, does not match dtype".format(
                        obj.file.filename, name,
                    ))
            if not shape or shape[1:] != oldshape[1:]:
                raise ValueError(
                    "Cannot merge {0}/{1}, shape {2} cannot be stacked "
                    "onto shape {3}".format(
                        obj.file.filename, name, shape, oldshape,
                    ))
            # stack along the first axis only
            shape = (oldshape[0] + shape[0],) + shape[1:]
        datasets[name] = (shape, dtype)

        # use default compression options from the first dataset scanned
        for copt in ('compression', 'compression_opts'):
            compression_kw.setdefault(copt, getattr(obj, copt))
        return None

    # get list of datasets
    for filename in tqdm.tqdm(inputfiles, desc="Scanning trigger files",
                              disable=not verbose, total=nfiles, unit="files",
                              **TQDM_KW):
        with h5py.File(filename, 'r') as h5f:
            # note: attributes are only recorded from the last file,
            #       since we presume all files have the same attributes
            attributes = dict(h5f.attrs)
            h5f.visititems(_scan_dataset)

    # print summary of what we found
    if verbose:
        print("Found {} events across {} files".format(
            datasets["network/event_id"][0][0],
            nfiles,
        ))

    def _network_first(name):
        """Key function to sort datasets in the network group first

        This is required so that when we merge the network/X1_event_id
        columns, we can reference the X1/event_id event count from the
        _previous_ iteration so get the correct increments.
        """
        if name.startswith("network"):
            return 0
        return 1

    # get ordered list of dataset names to loop over
    dataset_names = sorted(datasets, key=_network_first)

    # get event_id columns
    sngl_event_id_names = {
        x for x in dataset_names if x.endswith("/event_id")
    }
    network_event_id_names = {
        "network/{}".format(x.replace('/', '_')): x for
        x in sngl_event_id_names if not x.startswith("network")
    }

    # handle search datasets as a special case
    # (they will be the same in all files, so are only copied once)
    search_datasets = {x for x in dataset_names if "/search/" in x}

    # record where we are in the global dataset
    position = defaultdict(int)

    with h5py.File(outputfile, 'w') as h5out:
        _init_combined_file(h5out, attributes, datasets, **compression_kw)

        # copy dataset contents
        for filename in tqdm.tqdm(inputfiles, desc="Merging trigger files",
                                  disable=not verbose, total=nfiles,
                                  unit="files", **TQDM_KW):
            with h5py.File(filename, 'r') as h5in:
                # iterate over a copy, search datasets are removed below
                for dset in list(dataset_names):
                    if dset not in h5in:
                        # the scan did not count rows for this file either,
                        # so skipping keeps the positions consistent
                        continue

                    # get event_id handling options
                    if dset in network_event_id_names:
                        ifo_event_id_pos = position[network_event_id_names[dset]]
                    else:
                        ifo_event_id_pos = None

                    # merge dataset
                    position[dset] += _merge_dataset(
                        h5in[dset], h5out[dset], position[dset],
                        is_event_id=dset in sngl_event_id_names,
                        ifo_event_id_pos=ifo_event_id_pos,
                    )

                    # only read the search datasets once
                    if dset in search_datasets:
                        dataset_names.remove(dset)

    if verbose:
        print("Merged triggers written to {}".format(outputfile))
    return outputfile


def _times_in_bin(times, segl):
    """Return a boolean mask selecting the ``times`` that fall in ``segl``

    ``segl`` is either a single ``(start, end)`` segment, or a `list`
    (e.g. a `segmentlist`) of such segments; a time is selected if
    ``start <= t < end`` for any of them.
    """
    times = numpy.asarray(times)
    if isinstance(segl, list):
        mask = numpy.zeros(times.shape, dtype=bool)
        for seg in segl:
            mask |= (times >= seg[0]) & (times < seg[1])
        return mask
    return (times >= segl[0]) & (times < segl[1])


def _copy_rows(group, h5out, mask, bar):
    """Copy the ``mask``-selected rows of every dataset in ``group``

    Each dataset is written to ``h5out`` under the same name with the same
    compression settings, and ``bar`` is advanced once per dataset.
    """
    for old in group.values():
        if isinstance(old, h5py.Dataset):
            h5out.create_dataset(
                old.name,
                data=old[()][mask],
                compression=old.compression,
                compression_opts=old.compression_opts,
            )
            bar.update()


def bin_events(inputfile, bins, outdir, filetag,
               column="network/end_time_gc", verbose=False):
    """Separate events in the inputfile into bins

    For each bin, the network events whose ``column`` value falls in the
    bin are written to a new file, together with the single-IFO events
    they reference through the ``network/<IFO>_event_id`` columns.

    Parameters
    ----------
    inputfile : `str`
        path of the merged trigger file; its name is parsed with
        `gwdatafind.utils.filename_metadata` to get the IFO tag and the
        segment used in the output file names

    bins : `dict`
        mapping of bin name to either a single ``(start, end)`` segment,
        or a `list` of such segments; events are kept if
        ``start <= t < end`` for any of them

    outdir : `str`
        directory in which to write the output files, named
        ``<IFOTAG>-<filetag>_<bin>-<start>-<duration>.h5``

    filetag : `str`
        tag to include in the output file names

    column : `str`, optional
        dataset holding the network event times to bin

    verbose : `bool`, optional
        if `True`, show progress bars and print the output file names
    """
    ifotag, _, seg = filename_metadata(inputfile)

    with h5py.File(inputfile, "r") as h5in:
        ifos = [k for k in h5in.keys() if k != "network"]
        times = h5in[column][()]

        # number of datasets copied per bin (the same for every bin)
        nsets = sum(isinstance(item, h5py.Dataset) for
                    group in h5in.values() if isinstance(group, h5py.Group)
                    for item in group.values())

        for bin_, segl in bins.items():
            # find which network triggers to keep
            include = _times_in_bin(times, segl)

            # find which single-ifo events to keep
            ifo_index = {
                ifo: numpy.unique(
                    h5in["network/{}_event_id".format(ifo)][()][include],
                ) for ifo in ifos
            }

            # generate output file
            outf = os.path.join(
                outdir,
                '{}-{}_{}-{}-{}.h5'.format(
                    ifotag, filetag, bin_, seg[0], abs(seg),
                ),
            )

            msg = "Slicing {} network events for {}".format(
                include.sum(),
                bin_,
            )
            # context managers close the bar and file even on error
            with tqdm.tqdm(total=nsets, desc=msg, disable=not verbose,
                           unit="datasets", **TQDM_KW) as bar:
                with h5py.File(outf, "w") as h5out:
                    _copy_rows(h5in["network"], h5out, include, bar)
                    for ifo in ifos:
                        idx = numpy.isin(h5in[ifo]["event_id"][()],
                                         ifo_index[ifo])
                        _copy_rows(h5in[ifo], h5out, idx, bar)
            if verbose:
                print("{} written to {}".format(bin_, outf))


def read_segment_files(segdir):
    """Read the buffer, off-source and on-source segments of an analysis

    Parameters
    ----------
    segdir : `str`
        path of the directory holding the ``bufferSeg.txt``,
        ``offSourceSeg.txt`` and ``onSourceSeg.txt`` segwizard files

    Returns
    -------
    segs : `dict`
        mapping of ``'buffer'``, ``'off'`` and ``'on'`` to the single
        segment read from the matching file

    Raises
    ------
    ValueError
        if any file does not contain exactly one segment (parse errors
        from ``fromsegwizard`` are propagated unchanged)
    """
    filenames = {
        "buffer": "bufferSeg.txt",
        "off": "offSourceSeg.txt",
        "on": "onSourceSeg.txt",
    }
    segs = {}
    for name, basename in filenames.items():
        path = os.path.join(segdir, basename)
        with open(path, "r") as f:
            segl = fromsegwizard(f)
        if len(segl) != 1:
            raise ValueError(
                "expected exactly one segment in {}, found {}".format(
                    path, len(segl),
                ))
        segs[name], = segl
    return segs


# -- parse command line -------------------------

parser = argparse.ArgumentParser(
    description=__doc__,
)

parser.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    default=False,
    help="print progress information (default: %(default)s)",
)
parser.add_argument(
    "-V",
    "--version",
    action="version",
    version=__version__,
    help="show version number and exit",
)

# tags
parser.add_argument(
    "-i",
    "--ifo-tag",
    required=True,
    help="the IFO tag, e.g. H1L1",
)
parser.add_argument(
    "-u",
    "--user-tag",
    default="PYGRB",
    type=str.upper,
    help="the user tag (default: %(default)s)",
)
parser.add_argument(
    "-j",
    "--job-tag",
    type=str.upper,
    help="the job tag, for use when more than one trig_combiner "
         "job is included in a workflow",
)
# NOTE(review): args.slide_tag is not read anywhere else in this script
parser.add_argument(
    "-S",
    "--slide-tag",
    type=str.upper,
    help="the slide tag, used to differentiate long slides",
)

# run parameters
parser.add_argument(
    "-n", "--grb-name",
    type=str.upper,
    help="GRB event name, e.g. 010203",
)
parser.add_argument(
    "-T",
    "--num-trials",
    type=int,
    default=6,
    help="The number of off source trials, default: %(default)d",
)
parser.add_argument(
    "-p",
    "--trig-start-time",
    type=int,
    required=True,
    help="The start time of the analysis segment",
)
parser.add_argument(
    "-a",
    "--segment-dir",
    required=True,
    help="directory holding buffer, on and off source segment files",
)
parser.add_argument(
    "-s",
    "--short-slides",
    action="store_true",
    help="Did analysis use short time slides?",
)
parser.add_argument(
    "-t",
    "--long-slides",
    action="store_true",
    help="Are these triggers from long time slides?",
)

# input/output
# at least one file is required: the first one defines the analysis span
parser.add_argument(
    "-f",
    "--input-files",
    nargs="+",
    required=True,
    metavar="TRIGGER FILE",
    help="read in listed trigger files",
)
parser.add_argument(
    "-o",
    "--output-dir",
    default=os.getcwd(),
    help="output directory (default: %(default)s)",
)
parser.add_argument(
    "-c",
    "--no-compression",
    action="store_true",
    default=False,
    help="don't compress output files (default: %(default)s)",
)

args = parser.parse_args()

def vprint(*objects, **kwargs):
    """Print ``objects`` only if ``--verbose`` was given"""
    if args.verbose:
        print(*objects, **kwargs)


vprint("-- Welcome to the PyGRB trigger combiner")

if args.grb_name:
    args.user_tag += "_GRB{}".format(args.grb_name)
if args.job_tag:
    args.user_tag += "_{}".format(args.job_tag)

# the analysis span is taken from the name of the first trigger file
analysis = segments.segmentlist([file_segment(args.input_files[0])])
start, end = analysis[0]

if args.no_compression:
    # explicit None values stop merge_hdf5_files from inheriting the
    # compression settings of the input datasets
    compression_kw = {
        "compression": None,
        "compression_opts": None,
    }
else:
    compression_kw = {}

# -- construct segments -------------------------

segs = read_segment_files(args.segment_dir)
trialtime = abs(segs["on"])
bins = {
    "ONSOURCE": segs["on"],
    "OFFSOURCE": analysis - segments.segmentlist([segs["buffer"]]),
}
vprint("Parsed parameters and generated off-source trials:")
vprint("           trial time : {} seconds".format(trialtime))
vprint("    on-source segment : {}".format(bins["ONSOURCE"]))
vprint("       buffer segment : {}".format(segs["buffer"]))
vprint("  off-source segments : [{}]".format(
    ", \n                         ".join(map(str, bins["OFFSOURCE"])),
))

# construct consecutive off-source trial segments of length trialtime,
# starting at --trig-start-time
_ts = args.trig_start_time
for i in range(args.num_trials):
    _te = _ts + trialtime
    bins["OFFTRIAL_{}".format(i+1)] = seg = segments.segment(_ts, _te)
    _ts = _te
    vprint("          off-trial {} : {}".format(i+1, seg))

# -- read triggers ------------------------------

vprint("-- Merging events")

if args.short_slides:
    # short slides are unsupported with or without --long-slides
    raise NotImplementedError(
        "combining triggers from short time slides is not supported",
    )

outfilename = "{}-{}_ALL_TIMES-{}-{}.h5".format(
    args.ifo_tag, args.user_tag, start, end-start,
)
outfile = os.path.join(args.output_dir, outfilename)
merge_hdf5_files(
    args.input_files,
    outfile,
    verbose=args.verbose,
    **compression_kw
)

vprint("-- Binning events")
bin_events(outfile, bins, args.output_dir, args.user_tag, verbose=args.verbose)

vprint("-- All done")
back to top