Revision e29fea6c08b2ed7df1dc4e152ae5213cc83c0539 authored by Alex Nitz on 18 May 2020, 10:21:17 UTC, committed by GitHub on 18 May 2020, 10:21:17 UTC
1 parent 92bf8f0
Raw File
pycbc_compress_bank
#! /usr/bin/env python
# Copyright (C) 2016  Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
# Program description shown by argparse in the --help output.
__description__ = \
"""Loads a bank of waveforms and compresses them using the specified
compression algorithm. The resulting compressed waveforms are saved to an
hdf file."""

import argparse
import numpy
import h5py
import logging
import pycbc
from pycbc import psd, DYN_RANGE_FAC
from pycbc.waveform import compress
from pycbc import waveform
from pycbc.types import FrequencySeries, real_same_precision_as
from pycbc import pnutils
from pycbc import filter


# Build the command-line interface: bank input/output, waveform generation
# settings, compression settings, and (optionally) PSD options.
parser = argparse.ArgumentParser(description=__description__)
parser.add_argument("--bank-file", type=str, required=True,
                    help="Bank hdf file to load.")
parser.add_argument("--output", type=str, required=True,
                    help="The hdf file to save the templates and "
                    "compressed waveforms to.")
parser.add_argument("--low-frequency-cutoff", type=float, default=None,
                    help="The low frequency cutoff to use for generating "
                    "the waveforms (Hz). If this is not provided, the code "
                    "will look for a low frequency cutoff corresponding "
                    "to each template stored in the bank. If a "
                    "--low-frequency-cutoff is provided and the bank stores "
                    "a low frequency cutoff for each template as well, the "
                    "former is used to generate the waveforms.")
# add approximant arg
pycbc.waveform.bank.add_approximant_arg(parser)
parser.add_argument("--sample-rate", type=int, required=True,
                    help="Half this value sets the maximum frequency of "
                    "the compressed waveforms. Must be a power of 2.")
parser.add_argument("--segment-length", type=int, required=True,
                    help="The segment length to use for generating the "
                    "templates for compression, and for calculating "
                    "overlap for tolerance. Must be a power of 2, "
                    "and at least twice as long as the longest template "
                    "in the bank.")
parser.add_argument("--tmplt-index", nargs=2, type=int, default=None,
                    help="Only generate compressed waveforms for the given "
                    "indices in the bank file. Must provide both a start "
                    "and a stop index. Default is to compress all of the "
                    "templates in the bank file.")
parser.add_argument("--compression-algorithm", type=str, required=True,
                    choices=compress.compression_algorithms.keys(),
                    help="The compression algorithm to use for selecting "
                    "frequency points.")
parser.add_argument("--t-pad", type=float, default=0.001,
                    help="The minimum duration used for t(f) in seconds. "
                    "The inverse of this gives the maximum frequency "
                    "step that will be used in the compressed waveforms. "
                    "Default is 0.001.")
parser.add_argument("--tolerance", type=float, default=0.001,
                    help="The maximum mismatch to allow between the "
                    "interpolated waveform and the full waveform. "
                    "Points will be added to the compressed waveform "
                    "until its interpolation has a mismatch <= this value. "
                    "Default is 0.001.")
parser.add_argument("--interpolation", type=str, default="inline_linear",
                    help="The interpolation to use for decompressing the "
                    "waveforms for checking tolerance. Options are "
                    "'inline_linear', or any interpolation recognized by "
                    "scipy's interp1d kind argument. Default is inline_linear.")
# The waveforms are always generated and compressed in double precision;
# this option only selects the precision used when they are written out.
parser.add_argument("--precision", type=str, choices=["double", "single"],
                    default="single",
                    help="What precision to store the compressed "
                    "waveforms with; default is single.")
parser.add_argument("--force", action="store_true", default=False,
                    help="Overwrite the given hdf file if it exists. "
                    "Otherwise, an error is raised.")
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Enable verbose logging output.")

# Insert the PSD options
pycbc.psd.insert_psd_option_group(parser, include_data_options=False)

args = parser.parse_args()

pycbc.init_logging(args.verbose)

# The PSD options are optional for this program. Only validate them when at
# least one was supplied; verifying unconditionally would raise an error
# asking for PSD options even when no PSD is wanted.
if args.psd_model or args.psd_file or args.asd_file:
    psd.verify_psd_options(args, parser)

# Enforce the documented "power of 2" requirement. A positive integer n is a
# power of two exactly when n & (n - 1) == 0. Values below 2 are refused:
# 1 was already rejected by the previous evenness test, and 0 or negative
# values would break the divisions and array sizes computed below.
if args.sample_rate < 2 or args.sample_rate & (args.sample_rate - 1):
    raise ValueError("sample rate must be a power of 2")
if args.segment_length < 2 or args.segment_length & (args.segment_length - 1):
    raise ValueError("segment length must be a power of 2")

# frequency resolution (Hz), maximum frequency (Hz) and the number of
# time-domain samples in one segment
df = 1./args.segment_length
fmax = args.sample_rate/2.
N = args.sample_rate * args.segment_length

# load the bank
logging.info("loading bank")
# we'll do everything in double precision; if single is desired, we'll
# cast to single when saving the waveforms
dtype = numpy.complex128

# N is an even integer, so floor division gives the same value as N/2 did
# under Python 2 while keeping the filter length an int under Python 3
# (true division would pass a float length).
bank = waveform.FilterBank(args.bank_file, N // 2 + 1, df, dtype,
                           low_frequency_cutoff=args.low_frequency_cutoff,
                           approximant=args.approximant,
                           enable_compressed_waveforms=False)
templates = bank.table
if args.tmplt_index is not None:
    # restrict both the local view and the bank itself to the requested
    # [start, stop) slice so that later indexing of the bank matches
    imin, imax = args.tmplt_index
    templates = templates[imin:imax]
    bank.table = templates

# generate output file; the returned handle is written to while the
# templates are compressed below
logging.info("writing template info to output")
output = bank.write_to_hdf(args.output, force=args.force,
                           write_compressed_waveforms=False)

# get the psd
# NOTE: this rebinds the name ``psd`` (previously the imported pycbc.psd
# module) to the value returned by from_cli; the module remains reachable as
# ``pycbc.psd``.
logging.info("getting psd")
# N is an even integer; floor division keeps ``length`` an int under
# Python 3 (N/2+1 would be a float there).
psd = pycbc.psd.from_cli(args, length=N // 2 + 1, delta_f=df,
                         low_frequency_cutoff=templates.f_lower.min(),
                         dyn_range_factor=pycbc.DYN_RANGE_FAC,
                         precision='double')

# Compress every template of the (possibly sliced) bank and store the result
# in the output file.
logging.info("getting compressed amplitude and phase")

# Scratch buffer handed to compress_waveform; allocated once and reused for
# every template.
decomp_scratch = FrequencySeries(numpy.zeros(N, dtype=dtype), delta_f=df)

for ii in range(templates.size):
    # generate the full waveform and fetch its row of the template table
    htilde = bank[ii]
    tmplt = bank.table[ii]
    fmin = tmplt.f_lower

    # record the duration, then make sure the segment can hold the template
    template_duration = htilde.chirp_length
    output['template_duration'][ii] = template_duration
    if 2*template_duration > args.segment_length:
        raise ValueError("segment length is < twice the duration "
                         "({}) of template {}".format(template_duration,
                         tmplt.template_hash))

    # the waveform must be non-zero at the first bin at or above fmin
    kmin = int(numpy.ceil(fmin / df))
    if numpy.abs(htilde[kmin]) == 0:
        raise ValueError("""The amplitude of the waveform at the
                         low_frequency_cutoff is zero. A non-zero value
                         is required.""")

    # index of the last non-zero bin, and the frequency it corresponds to
    nonzero_bins = numpy.nonzero(abs(htilde))[0]
    kmax = nonzero_bins[-1]
    fhigh = kmax*df

    # choose the frequency sample points with the requested algorithm
    if args.compression_algorithm == 'mchirp':
        raw_points = compress.mchirp_compression(
            tmplt.mass1, tmplt.mass2, fmin, fhigh,
            min_seglen=args.t_pad, df_multiple=df)
    elif args.compression_algorithm == 'spa':
        raw_points = compress.spa_compression(
            htilde, fmin, fhigh, min_seglen=args.t_pad)
    else:
        raise ValueError("unrecognized compression algorithm %s" %(
            args.compression_algorithm))
    # match the real precision of the generated waveform
    sample_points = raw_points.astype(real_same_precision_as(htilde))

    # compress in double precision, checking the mismatch against
    # args.tolerance
    hcompressed = compress.compress_waveform(
        htilde, sample_points, args.tolerance, args.interpolation,
        'double', decomp_scratch=decomp_scratch, psd=psd)

    # store under the template's hash, in the requested output precision
    hcompressed.write_to_hdf(output, tmplt.template_hash,
                             precision=args.precision)

# ``output`` is the handle returned by bank.write_to_hdf and written to for
# every template; nothing else closes it, so close it explicitly to make sure
# all data reaches disk before the program exits.
output.close()
logging.info("finished")
# release the input bank file
bank.filehandler.close()
back to top