Revision a24be950c8a6ded839eff2786fc77d75e15e84ef authored by Ben Lackey on 07 March 2017, 16:59:35 UTC, committed by Ian Harry on 07 March 2017, 16:59:35 UTC
1 parent 903e9ea
Raw File
pycbc_compress_bank
#! /usr/bin/env python
# Copyright (C) 2016  Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
# Program description handed to argparse and shown in the --help output.
__description__ = \
"""Loads a bank of waveforms and compresses them using the specified
compression algorithm. The resulting compressed waveforms are saved to an
hdf file."""

import argparse
import numpy
import h5py
import logging
import pycbc
from pycbc import psd, DYN_RANGE_FAC
from pycbc.waveform import compress
from pycbc import waveform
from pycbc.types import FrequencySeries, real_same_precision_as
from pycbc import pnutils
from pycbc import filter


# Build the command-line interface. Options that the script cannot run
# without are marked as required so that argparse reports a clear usage
# error instead of the script failing later on a None value.
parser = argparse.ArgumentParser(description=__description__)
parser.add_argument("--bank-file", type=str, required=True,
                    help="Bank hdf file to load.")
parser.add_argument("--output", type=str, required=True,
                    help="The hdf file to save the templates and "
                    "compressed waveforms to.")
parser.add_argument("--low-frequency-cutoff", type=float, required=True,
                    help="The low frequency cutoff to use for generating "
                    "the waveforms (Hz).")
# add approximant arg
pycbc.waveform.bank.add_approximant_arg(parser)
parser.add_argument("--sample-rate", type=int, required=True,
                    help="Half this value sets the maximum frequency of "
                    "the compressed waveforms.")
parser.add_argument("--tmplt-index", nargs=2, type=int, default=None,
                    help="Only generate compressed waveforms for the given "
                    "indices in the bank file. Must provide both a start "
                    "and a stop index. Default is to compress all of the "
                    "templates in the bank file.")
parser.add_argument("--compression-algorithm", type=str, required=True,
                    choices=compress.compression_algorithms.keys(),
                    help="The compression algorithm to use for selecting "
                    "frequency points.")
parser.add_argument("--t-pad", type=float, default=0.001,
                    help="The minimum duration used for t(f) in seconds. "
                    "The inverse of this gives the maximum frequency "
                    "step that will be used in the compressed waveforms. "
                    "Default is 0.001.")
parser.add_argument("--tolerance", type=float, default=0.001,
                    help="The maximum mismatch to allow between the "
                    "interpolated waveform and the full waveform. "
                    "Points will be added to the compressed waveform "
                    "until its interpolation has a mismatch <= this value. "
                    "Default is 0.001.")
parser.add_argument("--interpolation", type=str, default="linear",
                    help="The interpolation to use for decompressing the "
                    "waveforms for checking tolerance. Options are "
                    "'linear', or any interpolation recognized by scipy's "
                    "interp1d kind argument. Default is linear.")
parser.add_argument("--precision", type=str, choices=["double", "single"],
                    default="double",
                    help="What precision to generate and store the "
                    "waveforms with; default is double.")
parser.add_argument("--force", action="store_true", default=False,
                    help="Overwrite the given hdf file if it exists. "
                    "Otherwise, an error is raised.")
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Enable verbose logging output.")

# Insert the PSD options
pycbc.psd.insert_psd_option_group(parser, include_data_options=False)

args = parser.parse_args()

# Set up logging, passing along the --verbose flag.
pycbc.init_logging(args.verbose)

# A PSD is optional: it is only used when computing the overlaps in
# compress.py. The PSD options are therefore validated only when at least
# one of them was actually supplied; otherwise the check would raise an
# error asking for PSD options that the user does not want.
psd_requested = args.psd_model or args.psd_file or args.asd_file
if psd_requested:
    psd.verify_psd_options(args, parser)

# Frequency limits: the low-frequency cutoff and half the sample rate.
fmin = args.low_frequency_cutoff
fmax = 0.5 * args.sample_rate

# load the bank
logging.info("loading bank")
# Pick the complex dtype matching the requested precision.
# FIXME: there really should be a function in pycbc.types that provides this
# mapping
dtype = numpy.complex64 if args.precision == "single" else numpy.complex128
# The N and df given here (5 and 0.25) are only placeholders; they are
# overwritten for each template in the compression loop below.
bank = waveform.FilterBank(args.bank_file, 5, 0.25, dtype,
                           low_frequency_cutoff=fmin,
                           approximant=args.approximant)
templates = bank.table
if args.tmplt_index is None:
    # no selection requested: use every template in the bank
    imin, imax = 0, templates.size
else:
    # restrict both the local table and the bank's table to [imin, imax)
    imin, imax = args.tmplt_index
    templates = templates[imin:imax]
    bank.table = templates

# figure out the dfs needed for each waveform
logging.info("getting needed dfs")
# Each segment length is 4 times the rough duration estimate rounded up to
# the next power of two, with a floor of 4/sample_rate on the rounded value
# so that there are at least 2 samples in a waveform.
min_seg_len = 4./args.sample_rate
rough_durations = [compress.rough_time_estimate(m1, m2, fmin)
                   for m1, m2 in zip(templates.mass1, templates.mass2)]
seg_lens = 4*numpy.array(
    [max(min_seg_len, 2**numpy.ceil(numpy.log2(duration)))
     for duration in rough_durations])

# generate output file
logging.info("writing template info to output")
# The handle returned here is the one the compressed waveforms are written
# into further down.
output = bank.write_to_hdf(args.output, force=args.force,
    skip_fields='template_duration')

# get the compressed sample points for each template
logging.info("getting compressed amplitude and phase")

# Create a dictionary to store the psd for every value of N that is used,
# so that templates sharing the same N reuse the same PSD.
psd_dict = {}

# Compress every loaded template in turn. Templates are addressed by their
# position in the (possibly sliced) bank table, which has one entry per
# element of seg_lens. Iterating this way means the loop cannot run past the
# templates that were actually loaded, even if the stop index given to
# --tmplt-index is larger than the bank.
for idx, seg_len in enumerate(seg_lens):
    # update the N, df to use based on this waveform
    N = int(args.sample_rate*seg_len)
    df = 1./seg_len
    bank.delta_f = df
    bank.N = N
    # floor division keeps the length an integer under Python 3 as well
    bank.filter_length = N//2 + 1
    bank.f_lower = fmin
    bank.kmin = int(bank.f_lower / df)
    # scratch space
    decomp_scratch = FrequencySeries(numpy.zeros(N, dtype=dtype), delta_f=df)
    # generate the waveform
    htilde = bank[idx]
    if numpy.abs(htilde[bank.kmin]) == 0:
        raise ValueError("""The amplitude of the waveform at the
                         low_frequency_cutoff is zero. A non-zero value
                         is required.""")
    tmplt = bank.table[idx]
    kmax = htilde.end_idx
    # Get the PSD used to compute the overlap between the full waveform and
    # the decompressed waveform. It is built once per value of N and cached.
    # The result is bound to its own name so that the imported ``psd``
    # module is not shadowed.
    # NOTE(review): from_cli is called even when no PSD options were given;
    # presumably it returns None in that case -- confirm.
    if N not in psd_dict:
        psd_dict[N] = pycbc.psd.from_cli(
            args, length=N//2 + 1, delta_f=df,
            low_frequency_cutoff=fmin,
            dyn_range_factor=pycbc.DYN_RANGE_FAC,
            precision=args.precision)
    template_psd = psd_dict[N]

    # get the compressed sample points
    if args.compression_algorithm == 'mchirp':
        sample_points = compress.mchirp_compression(tmplt.mass1, tmplt.mass2,
            fmin, kmax*df, min_seglen=args.t_pad, df_multiple=df).astype(
            real_same_precision_as(htilde))
    elif args.compression_algorithm == 'spa':
        sample_points = compress.spa_compression(htilde, fmin, kmax*df,
            min_seglen=args.t_pad).astype(real_same_precision_as(htilde))
    else:
        raise ValueError("unrecognized compression algorithm %s" %(
            args.compression_algorithm))

    # compress
    hcompressed = compress.compress_waveform(
        htilde, sample_points, args.tolerance, args.interpolation,
        decomp_scratch=decomp_scratch, psd=template_psd)

    # save results
    hcompressed.write_to_hdf(output, tmplt.template_hash)

logging.info("finished")
# close both the output file and the input bank file
output.close()
bank.filehandler.close()
back to top