#!/usr/bin/env python

"""Fast and simple sky localization for a specific compact binary merger event.
Runs a single-template matched filter on strain data from a number of detectors
and calls BAYESTAR to produce a sky localization from the resulting set of SNR
time series."""

import os
import argparse
import logging
import subprocess
import tempfile
import shutil

# we will make plots on a likely headless machine, so make sure matplotlib's
# backend is set appropriately
from matplotlib import use as mpl_use_backend
mpl_use_backend('agg')

import numpy as np
import h5py

from ligo.gracedb.rest import GraceDb

import pycbc
from pycbc.filter import sigmasq
from pycbc.io import live, WaveformArray
from pycbc.types import (TimeSeries, FrequencySeries, load_timeseries,
                         load_frequencyseries, MultiDetMultiColonOptionAction,
                         MultiDetOptionAction, MultiDetOptionAppendAction)
from pycbc.pnutils import nearest_larger_binary_number
from pycbc.waveform.spa_tmplt import spa_length_in_time
from pycbc import frame, waveform, dq
from pycbc.psd import interpolate


def default_frame_type(time, ifo):
    """Sensible defaults for frame types based on interferometer and time.
    """
    if time < 1137254517:
        # O1
        if ifo in ['H1', 'L1']:
            return ifo + '_HOFT_C02'
    elif time >= 1164556717 and time < 1235433618:
        # O2
        if ifo == 'V1':
            return 'V1O2Repro2A'
        elif ifo in ['H1', 'L1']:
            return ifo + '_CLEANED_HOFT_C02'
    elif time >= 1235433618:
        # O3
        if ifo == 'V1':
            return 'V1Online'
        elif ifo in ['H1', 'L1']:
            return ifo + '_HOFT_CLEAN_SUB60HZ_C01'
    raise ValueError('Detector {} not supported at time {}'.format(ifo, time))

def default_channel_name(time, ifo):
    """Sensible defaults for channel name based on interferometer and time.
    """
    if time < 1137254517:
        # O1
        if ifo in ['H1', 'L1']:
            return ifo + ':DCS-CALIB_STRAIN_C02'
    elif time >= 1164556717 and time < 1235433618:
        # O2
        if ifo == 'V1':
            return ifo + ':Hrec_hoft_V1O2Repro2A_16384Hz'
        elif ifo in ['H1', 'L1']:
            return ifo + ':DCH-CLEAN_STRAIN_C02'
    elif time >= 1235433618:
        # O3
        if ifo == 'V1':
            return ifo + ':Hrec_hoft_16384Hz'
        elif ifo in ['H1', 'L1']:
            return ifo + ':DCS-CALIB_STRAIN_CLEAN_SUB60HZ_C01'
    raise ValueError('Detector {} not supported at time {}'.format(ifo, time))

def main(trig_time, mass1, mass2, spin1z, spin2z, f_low, f_upper, sample_rate,
         ifar, ifos, thresh_SNR, ligolw_skymap_output='.',
         ligolw_event_output=None, snr_series_duration=0.1465,
         frame_types=None, channel_names=None,
         segment_source=None, segment_server=None,
         gracedb_server=None, test_event=True,
         custom_frame_files=None, approximant=None, detector_state=None,
         veto_definer=None, injection_file=None, fake_strain=None,
         fake_strain_seed=None, rescale_loglikelihood=None):

    if not test_event and not gracedb_server:
        raise RuntimeError('a GraceDB URL must be specified if not a test event.')

    tmpdir = tempfile.mkdtemp()

    if len(trig_time) == 1 and ':' not in trig_time[0]:
        # single approximate time given
        mean_trig_time = float(trig_time[0])
        trig_time_mode = 'approximate'
    else:
        # precise per-detector times given
        trig_time = {kv.split(':')[0]: float(kv.split(':')[1])
                     for kv in trig_time}
        mean_trig_time = np.mean([trig_time[ifo] for ifo in trig_time])
        trig_time_mode = 'exact'
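
    # the mode chosen here drives the peak selection below: in 'approximate'
    # mode the loudest SNR sample is searched for independently in each
    # detector, while in 'exact' mode the SNR sample closest to each given
    # per-detector time is used directly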

    if frame_types is None:
        frame_types = {}
    if channel_names is None:
        channel_names = {}
    for ifo in ifos:
        if ifo not in frame_types:
            frame_types[ifo] = default_frame_type(mean_trig_time, ifo)
        if ifo not in channel_names:
            if fake_strain[ifo] is not None:
                channel_names[ifo] = ifo + ':FAKE_DATA'
            else:
                channel_names[ifo] = default_channel_name(mean_trig_time, ifo)

    # resulting files will be tagged with this string
    file_name_tag = '{:.0f}'.format(mean_trig_time)

    # parameters to fit a single-template inspiral job nicely
    # around the trigger time, without requiring too much data

    pad_data = 8

    # padding: 16 * 2 s associated with the PSD estimation, plus a buffer for
    # other filtering
    pad = 40
    template_duration = spa_length_in_time(mass1=mass1, mass2=mass2,
                                           f_lower=f_low, phase_order=-1)
    segment_length = int(nearest_larger_binary_number(template_duration + pad))
    # set minimum so there is enough for a psd estimate
    if segment_length < 128:
        segment_length = 128
    logging.info('Using segment length: %s', segment_length)

    gps_end_time = int(mean_trig_time + pad / 2)
    gps_start_time = gps_end_time - segment_length
    logging.info("Using data: %s-%s", gps_start_time, gps_end_time)

    # if requested, remove detectors with missing/vetoed data
    # over the required segment
    required_duration = (gps_end_time + pad_data) - (gps_start_time - pad_data)
    unavailable_ifos = set()
    for ifo in ifos:
        if detector_state is None or ifo not in detector_state:
            continue
        on_segs = dq.query_str(ifo, detector_state[ifo],
                               gps_start_time - pad_data,
                               gps_end_time + pad_data,
                               veto_definer=veto_definer,
                               source=segment_source,
                               server=segment_server)
        if abs(on_segs) < required_duration:
            logging.info('Excluding %s due to missing or vetoed data', ifo)
            unavailable_ifos.add(ifo)
    ifos = sorted(set(ifos) - unavailable_ifos)
    if not ifos:
        raise RuntimeError('All detectors have been excluded due to '
                           'missing or vetoed data')

    highpass_frequency = int(f_low * 0.7)
    logging.info("Setting highpass: %s Hz", highpass_frequency)

    procs = []
    st_psd_paths = {}
    st_out_paths = {}
    for ifo in ifos:
        # compose the command line for the single-template process
        st_psd_paths[ifo] = os.path.join(
                tmpdir, 'PSD_{}_{}.txt'.format(file_name_tag, ifo))
        st_out_paths[ifo] = os.path.join(
                tmpdir, 'SNRTS_{}_{}.hdf'.format(file_name_tag, ifo))

        command = ["pycbc_single_template",
        "--verbose",
        "--segment-length", str(segment_length),
        "--segment-start-pad", "0",
        "--segment-end-pad", "0",
        "--psd-estimation", "median",
        "--psd-segment-length", "16",
        "--psd-segment-stride", "8",
        "--psd-inverse-length", "16",
        "--order", "-1",
        "--taper-data", "1",
        "--allow-zero-padding",
        "--autogating-threshold", "50",
        "--autogating-cluster", "0.1",
        "--autogating-width", "0.25",
        "--autogating-taper", "0.25",
        "--autogating-pad", "16",
        "--autogating-max-iterations", "5",
        "--strain-high-pass", str(highpass_frequency),
        "--pad-data", str(pad_data),
        "--chisq-bins", '0.72*get_freq("fSEOBNRv4Peak",params.mass1,params.mass2,params.spin1z,params.spin2z)**0.7',
        "--sample-rate", str(sample_rate),
        "--mass1", str(mass1),
        "--mass2", str(mass2),
        "--spin1z", str(spin1z),
        "--spin2z", str(spin2z),
        "--low-frequency-cutoff", str(f_low),
        "--gps-start-time", str(gps_start_time),
        "--gps-end-time", str(gps_end_time),
        "--trigger-time", '{:.6f}'.format(trig_time[ifo] if ifo in trig_time
                                          else mean_trig_time),
        "--window", '1',
        "--channel-name", channel_names[ifo],
        "--psd-output", st_psd_paths[ifo],
        "--output-file", st_out_paths[ifo]]

        command.append("--approximant")
        for apx in approximant:
            command.append(apx)
        if f_upper:
            command.append("--high-frequency-cutoff")
            command.append(str(f_upper))

        if injection_file:
            command.append("--injection-file")
            command.append(injection_file)

        if fake_strain[ifo] is not None:
            # use fake strain for this ifo
            command.append("--fake-strain")
            command.append(fake_strain[ifo])
            if fake_strain_seed[ifo] is not None:
                command.append("--fake-strain-seed")
                command.append(str(fake_strain_seed[ifo]))
        elif custom_frame_files is None or ifo not in custom_frame_files:
            # use default guesses for this ifo
            command.append("--frame-type")
            command.append(frame_types[ifo])
        else:
            # use custom frame files for this ifo
            logging.info("Check if the segment in the custom frame file is safe")
            fr_start_times = []
            fr_end_times = []
            for custom_frame in custom_frame_files[ifo]:
                try:
                    frame_data = frame.read_frame(custom_frame, channel_names[ifo])
                except RuntimeError:
                    msg = 'Channel name in {} is not {}'.format(
                            custom_frame, channel_names[ifo])
                    raise RuntimeError(msg)
                fr_start_times.append(frame_data.start_time)
                fr_end_times.append(frame_data.end_time)
            if gps_start_time < min(fr_start_times):
                msg = 'Start time of {} must be before the required start time {}'
                msg = msg.format(min(fr_start_times), gps_start_time)
                raise RuntimeError(msg)
            if max(fr_end_times) < gps_end_time:
                msg = 'End time of {} must be after the required end time {}'
                msg = msg.format(max(fr_end_times), gps_end_time)
                raise RuntimeError(msg)
            if mean_trig_time < min(fr_start_times) \
                    or max(fr_end_times) < mean_trig_time:
                msg = 'Trigger time must be within your frame file(s)'
                raise RuntimeError(msg)

            command.append("--frame-files")
            for custom_frame in custom_frame_files[ifo]:
                command.append(custom_frame)

        # create and open a file to record the
        # single-template process' stdout and stderr
        log_path = os.path.join(
                tmpdir,
                'pycbc_single_template_{}_{}_log.txt'.format(file_name_tag, ifo))
        log_file = open(log_path, 'w')
        log_file.write(' '.join(command) + '\n\n')
        log_file.flush()

        # start the single-template process
        proc = subprocess.Popen(command, stdout=log_file, stderr=log_file)
        procs.append((ifo, proc, log_path, log_file))

    logging.info('Waiting for pycbc_single_template to complete')

    snr_errors = False
    for ifo, proc, log_path, log_file in procs:
        proc.wait()
        if proc.returncode != 0:
            logging.error('%s pycbc_single_template failed, see %s',
                          ifo, log_path)
            snr_errors = True
        log_file.close()
    if snr_errors:
        raise RuntimeError('one or more pycbc_single_template failed, '
                           'please see the messages above for details.')

    logging.info('Gathering info from pycbc_single_template results')

    coinc_results = {}
    subthreshold_ifos = []
    network_snr2 = 0.
    for ifo in ifos:
        snr_series = load_timeseries(st_out_paths[ifo], group='snr')
        chisq_series = load_timeseries(st_out_paths[ifo], group='chisq')
        psd_series = load_frequencyseries(st_psd_paths[ifo])
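        # apply PyCBC's dynamic range scaling to the PSD so that it is
        # consistent with the (scaled) template and SNR series loaded above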
        psd_series *= pycbc.DYN_RANGE_FAC ** 2.0
        psd_series = psd_series.astype(np.float32)
        template_series = load_frequencyseries(st_out_paths[ifo],
                                               group='template')

        key = 'foreground/{}/'.format(ifo)
        coinc_results[key + 'mass1'] = mass1
        coinc_results[key + 'mass2'] = mass2
        coinc_results[key + 'spin1z'] = spin1z
        coinc_results[key + 'spin2z'] = spin2z
        coinc_results[key + 'f_lower'] = f_low
        if f_upper:
            coinc_results[key + 'f_final'] = f_upper
        # required by CandidateForGraceDB
        coinc_results[key + 'template_id'] = -1
        coinc_results[key + 'snr_series'] = snr_series
        coinc_results[key + 'psd_series'] = psd_series
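        # sigma^2 is <h|h> for the template weighted by this detector's PSD;
        # CandidateForGraceDB uses it, e.g., to compute effective distances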
        coinc_results[key + 'sigmasq'] = sigmasq(template_series, psd_series)

        if trig_time_mode == 'approximate':
            # find the absolute max SNR peak in the entire SNR time
            # series; it will be the trigger time in this detector
            # if above threshold
            # FIXME there is no guarantee that the light travel times
            # between the peaks in different detectors will be physical.
            # Any good ideas for how to enforce this in a simple way?
            snr_series_peak = np.argmax(abs(snr_series))
            snr = abs(snr_series[snr_series_peak])
            logging.info('%s peak SNR: %.2f', ifo, snr)
            if snr < thresh_SNR:
                subthreshold_ifos.append(ifo)
                continue
        elif trig_time_mode == 'exact':
            if ifo not in trig_time:
                # treat this detector as subthreshold
                subthreshold_ifos.append(ifo)
                continue
            snr_series_peak = int((trig_time[ifo] - snr_series.start_time) \
                    / snr_series.delta_t)
            assert snr_series_peak >= 0
            assert snr_series_peak < len(snr_series)
            snr = abs(snr_series[snr_series_peak])
            logging.info('%s SNR at given time: %.2f', ifo, snr)

        coinc_results[key + 'end_time'] = \
                float(snr_series.sample_times[snr_series_peak])
        coinc_results[key + 'snr'] = snr
        coinc_results[key + 'coa_phase'] = np.angle(snr_series[snr_series_peak])
        coinc_results[key + 'chisq'] = chisq_series[snr_series_peak]

        network_snr2 += snr ** 2

    trig_ifos = sorted(set(ifos) - set(subthreshold_ifos))
    if not trig_ifos:
        raise RuntimeError('All interferometers have SNR below threshold '
                           f'{thresh_SNR}. Is this really a candidate event?')

    coinc_results['foreground/stat'] = network_snr2 ** 0.5
    coinc_results['foreground/ifar'] = ifar

    # BAYESTAR's convention for end time of subthreshold detectors
    subthreshold_sngl_time = np.mean([
        coinc_results['foreground/%s/end_time' % ifo] for ifo in trig_ifos])

    window_bins = int(snr_series_duration * sample_rate)
    skyloc_data = {}
    for ifo in ifos:
        key = 'foreground/{}/'.format(ifo)

        if ifo in subthreshold_ifos:
            coinc_results[key + 'end_time'] = subthreshold_sngl_time

        skyloc_data[ifo] = {
            'psd': interpolate(coinc_results[key + 'psd_series'], 0.25)
        }

        # Go back to the SNR time series and select a piece centered on the
        # end time
        time = coinc_results[key + 'end_time']
        snr_series = coinc_results[key + 'snr_series']
        peak_bin = int((time - snr_series.start_time) / snr_series.delta_t)
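        # keep up to window_bins samples on each side of the peak, shrinking
        # the window symmetrically if the peak is too close to either end of
        # the SNR time series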
        max_bin = peak_bin + window_bins + 1
        if max_bin > len(snr_series):
            window_bins = len(snr_series) - peak_bin
            max_bin = len(snr_series)
            min_bin = peak_bin - window_bins + 1
        else:
            min_bin = peak_bin - window_bins
        if min_bin < 0:
            window_bins = peak_bin
            max_bin = peak_bin + window_bins + 1
            min_bin = peak_bin - window_bins

        skyloc_data[ifo]['snr_series'] = \
                snr_series[min_bin:max_bin].astype(np.complex64)

    kwargs = {'psds': {ifo: skyloc_data[ifo]['psd'] for ifo in ifos},
              'low_frequency_cutoff': f_low,
              'high_frequency_cutoff': f_upper,
              'skyloc_data': skyloc_data,
              'channel_names': channel_names}
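    # assemble everything into a coinc XML document (with SNR time series and
    # PSDs attached) that BAYESTAR can consume; it will either be uploaded to
    # GraceDB or saved locally below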

    doc = live.CandidateForGraceDB(
        trig_ifos,
        trig_ifos,
        coinc_results,
        **kwargs
    )

    ligolw_file_path = os.path.join(tmpdir, file_name_tag + '.xml')

    if gracedb_server:
        comments = ['Manual followup from PyCBC']
        gid = doc.upload(ligolw_file_path,
                         gracedb_server=gracedb_server,
                         testing=test_event,
                         extra_strings=comments)
        gracedb = GraceDb(gracedb_server)
    else:
        doc.save(ligolw_file_path)

    # Defining the approximant to feed into BAYESTAR
    row = WaveformArray.from_kwargs(
        mass1=mass1,
        mass2=mass2,
        spin1z=spin1z,
        spin2z=spin2z)
    bs_approximant = waveform.bank.parse_approximant_arg(approximant, row)[0]

    # BAYESTAR uses TaylorF2 instead of SPAtmplt
    if bs_approximant == 'SPAtmplt':
        bs_approximant = 'TaylorF2'

    # run BAYESTAR to generate the skymap
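    # the coinc XML file is passed twice: bayestar-localize-coincs takes the
    # coinc input followed by the file(s) containing the PSDs, and here a
    # single file plays both roles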
    cmd = ['bayestar-localize-coincs',
           ligolw_file_path,
           ligolw_file_path,
           '--waveform', str(bs_approximant),
           '--f-low', str(f_low),
           '-o', tmpdir]
    if rescale_loglikelihood is not None:
        cmd += ['--rescale-loglikelihood', rescale_loglikelihood]
    subprocess.call(cmd)

    skymap_fits_name = os.path.join(tmpdir, '0.fits')

    # plot the skymap

    skymap_plot_path = os.path.join(ligolw_skymap_output,
                                    file_name_tag + '_skymap.png')
    cmd = ['ligo-skymap-plot',
           skymap_fits_name,
           '-o', skymap_plot_path,
           '--contour', '50', '90',
           '--annotate']
    subprocess.call(cmd)

    final_fits_path = os.path.join(ligolw_skymap_output,
                                   file_name_tag + '.fits')
    shutil.move(skymap_fits_name, final_fits_path)

    if gracedb_server:
        gracedb.write_log(
            gid,
            'Bayestar skymap FITS file upload',
            filename=final_fits_path,
            tag_name=['sky_loc'],
            displayName=['Bayestar FITS skymap']
        )
        gracedb.write_log(
            gid,
            'Bayestar skymap plot upload',
            filename=skymap_plot_path,
            tag_name=['sky_loc'],
            displayName=['Bayestar skymap plot']
        )

    if ligolw_event_output:
        shutil.move(ligolw_file_path, ligolw_event_output)
    shutil.rmtree(tmpdir)

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description=__doc__)
    pycbc.add_common_pycbc_options(parser)
    parser.add_argument('--version', action=pycbc.version.Version)
    # note that I am not using a MultiDetOptionAction for --trig-time as I
    # explicitly want to handle cases like `--trig-time 1234` and
    # `--trig-time H1:1234 L1:1234` in different ways
    parser.add_argument('--trig-time', required=True, nargs='+',
                        help='GPS time of the trigger. Can be a single value, '
                             'or a list of per-detector times using the '
                             'format H1:123456 L1:123457. In the first case, '
                             'the given time is only an approximate guess of '
                             'the trigger time, and the code will find the '
                             'maximum-SNR peak at each detector on its own. '
                             'The given time should be within 1 s of the '
                             'peak; this procedure becomes unreliable for '
                             'low-SNR triggers. When per-detector times are '
                             'given instead, they indicate exactly which '
                             'sample of the SNR time series is taken as the '
                             'peak SNR at each detector. Unspecified '
                             'detectors are treated as subthreshold and their '
                             'time is set to the mean of the given times.')
    parser.add_argument('--mass1', type=float, required=True)
    parser.add_argument('--mass2', type=float, required=True)
    parser.add_argument('--spin1z', type=float, required=True)
    parser.add_argument('--spin2z', type=float, required=True)
    parser.add_argument('--approximant', type=str, nargs='+',
                        default=["SPAtmplt:mtotal<4", "SEOBNRv4_ROM:else"])
    parser.add_argument('--f-low', type=float,
                        help="lower frequency cut-off (float)", default=20.0)
    parser.add_argument('--f-upper', type=float,
                        help="upper frequency cut-off (float)")
    parser.add_argument('--sample-rate', type=int, default=2048,
                        help='sample rate of the data')
    parser.add_argument('--ifar', type=float, default=1,
                        help='Inverse false alarm rate to assign to the '
                             'candidate event')
    parser.add_argument('--thresh-SNR', type=float, default=4.5,
                        help='When `--trig-time` is given a single (approximate) time, only '
                             'detectors with SNR above this threshold are used '
                             'to find the precise trigger times. The others are still '
                             'used for sky localization, but do not determine '
                             'the trigger time.')
    parser.add_argument('--ifos', type=str, required=True, nargs='+',
                        help='List of interferometer names, e.g. H1 L1')
    parser.add_argument('--frame-type', type=str, nargs='+',
                        help='Frame type for each detector, e.g. '
                             'H1:H1_HOFT_C02 L1:L1_HOFT_C02. If not given, '
                             'a default is guessed from the trigger time.')
    parser.add_argument('--channel-name', type=str, nargs='+',
                        help='Strain channel for each detector, e.g. '
                             'H1:DCS-CALIB_STRAIN_C02. If not given, a '
                             'default is guessed from the trigger time.')
    parser.add_argument('--detector-state', type=str, nargs='+',
                        action=MultiDetMultiColonOptionAction,
                        help='Use the segment database to exclude detectors '
                             'with missing or vetoed data')
    parser.add_argument('--veto-definer', type=str,
                        help='Optional path to a veto definer file to resolve '
                             'macro flags in --detector-state')
    parser.add_argument('--segment-source', choices=['any', 'GWOSC', 'dqsegdb'],
                        default='any',
                        help='When using `--detector-state`, this controls '
                             'whether to query the GWOSC or dqsegdb server')
    parser.add_argument('--segment-server', metavar='URL',
                        help='URL of segment server to query when using '
                             '`--detector-state`',
                        default='https://segments.ligo.org')
    parser.add_argument('--ligolw-skymap-output', type=str, default='.',
                        help='Directory where the sky map FITS file and plot '
                             'will be written')
    parser.add_argument('--ligolw-event-output', type=str,
                        help='If given, keep the coinc XML file at this path')
    parser.add_argument('--enable-production-gracedb-upload',
                        action='store_true', default=False,
                        help='Do not mark triggers uploaded to GraceDB as test '
                             'events. This option should *only* be enabled in '
                             'production analyses!')
    parser.add_argument('--gracedb-server', metavar='URL',
                        help='URL of GraceDB server API for uploading events.')
    parser.add_argument('--custom-frame-file', type=str, nargs='+',
                        action=MultiDetOptionAppendAction,
                        help='Lists of local frame files, e.g., '
                             'H1:/path/to/frame/file L1:/path/to/frame/file')
    parser.add_argument('--injection-file', type=str,
                        help='Optional path to an injection file.')
    parser.add_argument('--fake-strain', type=str,
                        action=MultiDetOptionAction, metavar="IFO:FAKE_STRAIN",
                        help='Name of the model PSD used to generate fake '
                             'Gaussian noise for the given detector, instead '
                             'of reading real data')
    parser.add_argument('--fake-strain-seed', type=int,
                        action=MultiDetOptionAction, metavar="IFO:FAKE_STRAIN_SEED",
                        help='Seed of the random generator used to produce '
                             'the fake strain for the given detector')
    parser.add_argument('--rescale-loglikelihood',
                        help="SNR rescaling factor used in BAYESTAR")

    opt = parser.parse_args()

    pycbc.init_logging(opt.verbose)

    frame_type_dict = {f.split(':')[0]: f.split(':')[1] for f in opt.frame_type} \
            if opt.frame_type is not None else None
    chan_name_dict = {f.split(':')[0]: f for f in opt.channel_name} \
            if opt.channel_name is not None else None

    main(opt.trig_time, opt.mass1, opt.mass2,
         opt.spin1z, opt.spin2z, opt.f_low, opt.f_upper, opt.sample_rate,
         opt.ifar, opt.ifos, opt.thresh_SNR, opt.ligolw_skymap_output,
         opt.ligolw_event_output,
         frame_types=frame_type_dict, channel_names=chan_name_dict,
         gracedb_server=opt.gracedb_server,
         test_event=not opt.enable_production_gracedb_upload,
         custom_frame_files=opt.custom_frame_file,
         approximant=opt.approximant, detector_state=opt.detector_state,
         segment_source=opt.segment_source, segment_server=opt.segment_server,
         veto_definer=opt.veto_definer, injection_file=opt.injection_file, 
         fake_strain=opt.fake_strain, fake_strain_seed=opt.fake_strain_seed,
         rescale_loglikelihood=opt.rescale_loglikelihood)