Revision d44b370e336fd20143c901199f90ab5e2aa19a46 authored by Alex Nitz on 30 April 2017, 15:47:25 UTC, committed by GitHub on 30 April 2017, 15:47:25 UTC
1 parent 29416a2
Raw File
pycbc_mvsc_get_features
#! /usr/bin/python
import sqlite3 
from pycbc_glue.ligolw import table, lsctables, dbtables
from pycbc import events
from pylal import git_version
from time import clock,time
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from argparse import *
import random
import itertools
import numpy as np
import fileinput

# Text shown by the argument parser; it is passed as the parser's ``usage`` argument.
usage="""
This code constructs features using the properties of double coincident triggers. 
Single and coincident trigger features may be selected from the list of columns of 
single inspiral and coincident tables in the database supplied.

User specified features are constructed from sqlite databases and saved into .pat files for MVSC.
The columns of the .pat file correspond to the features that characterize double coincident triggers.
"""

# Wall-clock start time; the total run time is reported at the very end of the script.
time1 = time()

# The last address was missing its closing '>' bracket.
__author__ = "Kari Hodge <khodge@ligo.caltech.edu>, Shasvath J. Kapadia <skapadia@email.uark.edu>, Thomas Dent <thomas.dent@aei.mpg.de>"

# Print tracebacks for exceptions raised inside sqlite3 user-defined callbacks
# instead of silently discarding them.
sqlite3.enable_callback_tracebacks(True)

############## User Defined Input ##############################
parser=ArgumentParser(usage=usage)
# The ``version=`` constructor argument is deprecated in Python 2.7 and absent in Python 3;
# this explicit option registers the same -v/--version flags it used to create.
parser.add_argument("-v", "--version", action="version", version=git_version.verbose_msg, help="show program's version number and exit")
parser.add_argument("--number", default=10, type=int, help="folds for cross-validation/number for round robin. Default 10")
parser.add_argument("--factor", default=50.0, type=float, help="the value of the magic number factor in the effective snr formula, should be 50 for highmass and 250 for lowmass. Default 50")
# Required: the script splits this value into the two detector names used in every query.
parser.add_argument("--instruments", required=True, help="pair of detectors. Example: H1,L1")
parser.add_argument("--output-tag", default="CBC", help="a string added to all filenames")
#parser.add_argument("--apply-weights", action="store_true", default=False, help="calculates weight for all found injections. Weights are saved in .pat file (all noise events get weight=1). If this option is not supplied, all (noise+signal) events get a weight of 1.")
parser.add_argument("--weighting-method", default=None, choices=['physical_distance', 'chirp_distance'], help = "Specify weighting method/function for assigning weights to triggers. If not given, then all (noise+signal) events get a weight of 1.")
parser.add_argument("--reference-dist", type=float, default=300, help="distance (Mpc) for normalizing injection weights when training. Default 300")
# -1 is the sentinel meaning "no lower bound" (see Trigger.set_weight).
parser.add_argument("--min-weight", type=float, default = -1, help="If applied, sets the lower bound of the signal weights to be min-weight")
parser.add_argument("--distance-power", type=float, default = 1, help="Raise injected distance or chirp distance to given power before calculating weights")
parser.add_argument("--plot-weights-histogram", action="store_true", default=False, help="plots a histogram for weights assigned to triggers")
parser.add_argument("--plot-path", default = ".", dest="plot_path", help="Specify path to store weights-histogram. Default: current directory")
parser.add_argument("--plot-name", default = "weights_histogram.png", dest="plot_name", help="Specify name of histogram plot. Default: weights_histogram.png")
parser.add_argument("--exact-tag", default="ring_exact", help="this is the dbinjfind tag stored in the sqlite database for the exactly found injections - the ones you want to use for training")
parser.add_argument("-s","--tmp-space",help="necessary for sqlite calls, for example /local/user/shasvath.kapadia on Atlas")
parser.add_argument("--nearby-tag", default="ring_nearby", help="this is the dbinjfind tag stored in the sqlite for the nearby injections - we will still rank all of these")
parser.add_argument("--sngl-features", nargs="*", default=[], dest="sngl_features_list", help='Specify list of single trigger features to be extracted directly from sngl table. Features should be separated by spaces. Example: snr chisq')
# The choices below must match the calc_<feature> method names of the Trigger class.
parser.add_argument("--sngl-derived-features", nargs="*", default=[], dest="sngl_derived_features_list", choices=['reduced_chisq','reduced_cont_chisq','reduced_bank_chisq','effective_snr','bank_effective_snr','cont_effective_snr', 'newsnr'], help='Specify list of derived single features separated by spaces. Example: reduced_chisq effective_snr')
parser.add_argument("--coinc-features", nargs="*", default=[], dest="coinc_features_list", help='Specify list of features extracted from coinc_inspiral table, separated by spaces. Example: snr mass')
parser.add_argument("--difference-features", nargs="*", default=[], dest="difference_features_list", choices=['delta_t','delta_phi','eff_dist_ratio'], help='Specify list of difference features, separated by spaces. Example: delta_t delta_phi')
# Required: path of the sqlite database every trigger is read from.
parser.add_argument("--input-database", dest="database", required=True)
parser.add_argument("--verbose", default=False, action="store_true")


## Function to put headers to feature files (.pat files)
def put_headers(filename, features_list):
	"""Prepend the two-line header to an existing .pat file, rewriting it in place.

	The header consists of the number of features on the first line and the
	tab-separated feature names on the second line; the original file content
	follows unchanged.

	Parameters:
		filename: path of the file to rewrite.
		features_list: list of feature-name strings.

	An empty file ends up containing just the header.
	"""
	# Local import: the module-level imports of this script do not include sys.
	import sys

	header = str(len(features_list)) + '\n' + '\t'.join(features_list) + '\n'
	header_written = False
	try:
		# With inplace=True, fileinput redirects sys.stdout into the file being
		# rewritten, so everything written to stdout here lands in that file.
		for line in fileinput.input([filename], inplace=True):
			if fileinput.isfirstline():
				sys.stdout.write(header)
				header_written = True
			sys.stdout.write(line)
	finally:
		# Release the module-global fileinput state and restore sys.stdout even
		# if the loop raised; otherwise the next fileinput.input() call fails.
		fileinput.close()

	if not header_written:
		# The file had no lines, so the loop body never ran.
		with open(filename, 'w') as pat_file:
			pat_file.write(header)


################ Functions used by the trigger class #############

def reduced_chisq_func(chisq, chisq_dof):
	"""Return chisq divided by (2*chisq_dof - 2)."""
	denominator = 2*chisq_dof - 2
	return chisq/denominator

def reduced_cont_chisq_func(cont_chisq, cont_chisq_dof):
	"""Return cont_chisq/cont_chisq_dof, or 1 when cont_chisq_dof is zero."""
	# Guard clause: only divide when the number of degrees of freedom is nonzero.
	if cont_chisq_dof != 0.:
		return cont_chisq/cont_chisq_dof
	return 1

def reduced_bank_chisq_func(bank_chisq, bank_chisq_dof):
	"""Return bank_chisq/bank_chisq_dof, or 1 when bank_chisq_dof is zero."""
	return 1 if bank_chisq_dof == 0. else bank_chisq/bank_chisq_dof

def mod2pi(phi):
	"""Return phi modulo 2*pi, computed with numpy.mod."""
	full_turn = 2*np.pi
	return np.mod(phi, full_turn)

def chirp_dist_func(distance,mchirp):
	"""Return lsctables.SnglInspiral.chirp_distance evaluated at (distance, mchirp)."""
	to_chirp_distance = lsctables.SnglInspiral.chirp_distance
	return to_chirp_distance(distance, mchirp)

# Columns that are always copied from the detector-A and detector-B query rows onto
# every Trigger (as a_<column> / b_<column>), in addition to the user-selected
# single-detector features. The calc_* methods of Trigger read these attributes.
derived_feature_columns = ['snr','chisq','chisq_dof','bank_chisq','bank_chisq_dof','cont_chisq','cont_chisq_dof',
    'coa_phase','end_time','end_time_ns','offset','eff_distance']
#######################################################


################ Trigger class for all types of triggers (exact-inj, all-inj, timeslide, zerolag) #####################
class Trigger(object):
	"""Feature container for one double-coincident trigger.

	The constructor only stores bookkeeping values. The per-detector columns
	(``a_<column>`` and ``b_<column>``) that the ``calc_*`` methods read are
	expected to be set on the instance by the caller before those methods are
	invoked; each ``calc_*`` method stores its result as new attribute(s).
	"""

	## Class instance using datatype and signal/noise boolean
	def __init__(self, args, dtype, sn_bool, coinc_inspiral_time):
		"""Store the effective-snr factor (args.factor), the datatype label,
		the signal(1)/noise(0) flag and the coincident end time."""
		self.fac = args.factor
		self.dtype = dtype
		self.sn_bool = sn_bool
		self.coinc_inspiral_time = coinc_inspiral_time

	## Single derived features:
	def calc_reduced_chisq(self):
		"""Set a_/b_reduced_chisq from a_/b_chisq and a_/b_chisq_dof."""
		self.a_reduced_chisq = reduced_chisq_func(self.a_chisq, self.a_chisq_dof)
		self.b_reduced_chisq = reduced_chisq_func(self.b_chisq, self.b_chisq_dof)

	def calc_reduced_cont_chisq(self):
		"""Set a_/b_reduced_cont_chisq from a_/b_cont_chisq and a_/b_cont_chisq_dof."""
		self.a_reduced_cont_chisq = reduced_cont_chisq_func(self.a_cont_chisq, self.a_cont_chisq_dof)
		self.b_reduced_cont_chisq = reduced_cont_chisq_func(self.b_cont_chisq, self.b_cont_chisq_dof)

	def calc_reduced_bank_chisq(self):
		"""Set a_/b_reduced_bank_chisq from a_/b_bank_chisq and a_/b_bank_chisq_dof."""
		self.a_reduced_bank_chisq = reduced_bank_chisq_func(self.a_bank_chisq, self.a_bank_chisq_dof)
		self.b_reduced_bank_chisq = reduced_bank_chisq_func(self.b_bank_chisq, self.b_bank_chisq_dof)

	def calc_effective_snr(self):
		"""Set a_/b_effective_snr; requires calc_reduced_chisq to have run."""
		self.a_effective_snr = events.effsnr(self.a_snr, self.a_reduced_chisq, self.fac)
		self.b_effective_snr = events.effsnr(self.b_snr, self.b_reduced_chisq, self.fac)

	def calc_bank_effective_snr(self):
		"""Set a_/b_bank_effective_snr; requires calc_reduced_bank_chisq to have run."""
		self.a_bank_effective_snr = events.effsnr(self.a_snr, self.a_reduced_bank_chisq, self.fac)
		self.b_bank_effective_snr = events.effsnr(self.b_snr, self.b_reduced_bank_chisq, self.fac)

	def calc_cont_effective_snr(self):
		"""Set a_/b_cont_effective_snr; requires calc_reduced_cont_chisq to have run."""
		self.a_cont_effective_snr = events.effsnr(self.a_snr, self.a_reduced_cont_chisq, self.fac)
		self.b_cont_effective_snr = events.effsnr(self.b_snr, self.b_reduced_cont_chisq, self.fac)

	def calc_newsnr(self):
		"""Set a_/b_newsnr; requires calc_reduced_chisq to have run."""
		self.a_newsnr = events.newsnr(self.a_snr, self.a_reduced_chisq)
		self.b_newsnr = events.newsnr(self.b_snr, self.b_reduced_chisq)

	##  Difference features:
	def calc_delta_t(self):
		"""Set delta_t to the difference of the two end times, each shifted by
		its detector's time-slide offset."""
		self.delta_t = (lsctables.LIGOTimeGPS(self.a_end_time + self.a_offset, self.a_end_time_ns)
				-lsctables.LIGOTimeGPS(self.b_end_time + self.b_offset, self.b_end_time_ns))

	def calc_delta_phi(self):
		"""Set delta_phi to the coalescence-phase difference modulo 2*pi."""
		self.delta_phi = mod2pi(self.a_coa_phase-self.b_coa_phase)

	def calc_eff_dist_ratio(self):
		"""Set eff_dist_ratio to a_eff_distance divided by b_eff_distance."""
		self.eff_dist_ratio = self.a_eff_distance/self.b_eff_distance

	##  Assign weight to signal triggers, if specified by user and if the injection
	##  distance distribution is recognized.

	def set_weight(self, weighting_method, min_weight, distance_power, reference_dist=300.):
		"""Set self.weight.

		The weight stays 1 for noise triggers (sn_bool != 1), when no
		weighting method is given, and for distance distributions other than
		"uniform" and "log10". For weighted signal triggers the attributes
		distance, sim_mchirp and d_distr must already be set.

		Parameters:
			weighting_method: None, "physical_distance" or "chirp_distance".
			min_weight: lower bound applied to weighted signal triggers;
				-1 disables the bound.
			distance_power: extra exponent applied to the squared/cubed ratio.
			reference_dist: distance used to normalise the (chirp) distance.
		"""
		# default value for zerolag, time slides and unrecognized distributions
		self.weight = 1
		if weighting_method is None or self.sn_bool != 1:
			return

		if weighting_method in ("physical_distance", "chirp_distance"):
			# Record the chirp distance for both methods: the histogram code reads
			# trig.chirp_dist for every found injection, and it used to be missing
			# (AttributeError) when weighting by physical distance.
			self.chirp_dist = chirp_dist_func(self.distance, self.sim_mchirp)
			if weighting_method == "physical_distance":
				scaled_dist = self.distance/reference_dist
			else:
				scaled_dist = self.chirp_dist/reference_dist

			if self.d_distr == "uniform":
				self.weight = (scaled_dist**2)**distance_power
			elif self.d_distr == "log10":
				self.weight = (scaled_dist**3)**distance_power
			# "volume" and any unrecognized distribution keep weight 1

		if min_weight != -1 and self.weight < min_weight:
			self.weight = min_weight
					


#####################################################################

############### Strings for sqlite queries ##########################
# The fragments below are concatenated further down into complete queries:
# one SELECT fragment followed by one FROM/WHERE fragment.

# SELECT fragment: every sngl_inspiral column of detector A plus its time-slide offset.
select_dimensions_detA="""
	SELECT
		snglA.*,  
		tsA.offset """

# SELECT fragment: every sngl_inspiral column of detector B plus its time-slide offset.
select_dimensions_detB="""
	SELECT
		snglB.*, 
		tsB.offset"""

# SELECT fragment: every coinc_inspiral column.
select_dimensions_coinc="""
	SELECT
		coinc_inspiral.*"""

# SELECT fragment: injection information. process_params.value is the value of the
# inspinj --d-distr / --dchirp-distr option selected by the WHERE clause below.
select_dimensions_sim="""
	SELECT
		process_params.value,
		sim_inspiral.distance,
		sim_inspiral.mchirp"""

# FROM/WHERE fragment for injections: links each coinc_inspiral row to its two
# sngl_inspiral rows (mapA/mapB), to the sim_inspiral row it is mapped to
# (mapC/mapD) and to the per-detector time-slide offsets.
# Positional parameters: detector A ifo, detector B ifo.
add_from_injections="""
	FROM
		coinc_inspiral
		JOIN coinc_event_map AS mapA ON (mapA.coinc_event_id == coinc_inspiral.coinc_event_id)
		JOIN coinc_event_map AS mapB ON (mapB.coinc_event_id == coinc_inspiral.coinc_event_id)
		JOIN sngl_inspiral AS snglA ON (snglA.event_id == mapA.event_id)
		JOIN sngl_inspiral AS snglB ON (snglB.event_id == mapB.event_id)
		JOIN coinc_event_map AS mapC ON (mapC.event_id == coinc_inspiral.coinc_event_id)
		JOIN coinc_event_map AS mapD ON (mapD.coinc_event_id == mapC.coinc_event_id)
		JOIN sim_inspiral ON (sim_inspiral.simulation_id == mapD.event_id)
		JOIN coinc_event AS sim_coinc_event ON (sim_coinc_event.coinc_event_id == mapD.coinc_event_id)
		JOIN coinc_event AS insp_coinc_event ON (insp_coinc_event.coinc_event_id == mapA.coinc_event_id)
		JOIN coinc_definer ON (coinc_definer.coinc_def_id == sim_coinc_event.coinc_def_id)
		JOIN process_params ON (process_params.process_id == sim_inspiral.process_id)
		JOIN time_slide AS tsA ON (tsA.time_slide_id=insp_coinc_event.time_slide_id)
		JOIN time_slide AS tsB ON (tsB.time_slide_id=insp_coinc_event.time_slide_id)
	WHERE
		mapA.table_name == 'sngl_inspiral'
		AND mapB.table_name == 'sngl_inspiral'
		AND mapC.table_name == 'coinc_event'
		AND mapD.table_name == 'sim_inspiral'
		AND tsA.instrument == snglA.ifo
		AND tsB.instrument == snglB.ifo
		AND snglA.ifo == ?
		AND snglB.ifo == ?
		AND process_params.program == 'inspinj'
		AND (process_params.param == '--d-distr' OR process_params.param == '--dchirp-distr')"""

# Continuation of the injection WHERE clause: a third positional parameter selects
# the coinc_definer description; rows are ordered by coinc end time.
add_where_injections="""
		AND coinc_definer.description == ?
		ORDER BY coinc_inspiral.end_time+coinc_inspiral.end_time_ns*.000000001"""

# Extra SELECT column for the all-data queries; the script later uses it to split
# rows into time slides and zero-lag.
add_select_fulldata="""
		, experiment_summary.datatype"""

# FROM/WHERE fragment for the all-data queries: inspiral coincidences
# with their experiment datatype and per-detector time-slide offsets.
# Positional parameters: detector A ifo, detector B ifo.
add_from_fulldata="""
	FROM
		coinc_inspiral
		JOIN coinc_event_map AS mapA ON (mapA.coinc_event_id == coinc_inspiral.coinc_event_id)
		JOIN coinc_event_map AS mapB ON (mapB.coinc_event_id == coinc_inspiral.coinc_event_id)
		JOIN sngl_inspiral AS snglA ON (snglA.event_id == mapA.event_id)
		JOIN sngl_inspiral AS snglB ON (snglB.event_id == mapB.event_id)
		JOIN coinc_event AS insp_coinc_event ON (mapA.coinc_event_id == insp_coinc_event.coinc_event_id)
		JOIN coinc_definer ON (coinc_definer.coinc_def_id == insp_coinc_event.coinc_def_id)
		JOIN experiment_map ON (experiment_map.coinc_event_id == coinc_inspiral.coinc_event_id)
		JOIN experiment_summary ON (experiment_summary.experiment_summ_id == experiment_map.experiment_summ_id)
		JOIN time_slide AS tsA ON (tsA.time_slide_id=insp_coinc_event.time_slide_id)
		JOIN time_slide AS tsB ON (tsB.time_slide_id=insp_coinc_event.time_slide_id)
	WHERE
		coinc_definer.search == 'inspiral'
		AND coinc_definer.search_coinc_type == 0
		AND mapA.table_name == 'sngl_inspiral'
		AND mapB.table_name == 'sngl_inspiral'
		AND tsA.instrument == snglA.ifo
		AND tsB.instrument == snglB.ifo
		AND snglA.ifo == ?
		AND snglB.ifo == ?
		ORDER BY coinc_inspiral.coinc_event_id"""


# Shared tails of the queries: FROM/WHERE for injections (3 positional parameters)
# and extra SELECT column + FROM/WHERE for all data (2 positional parameters).
acquire_injections = add_from_injections + add_where_injections
acquire_all_data = add_select_fulldata + add_from_fulldata

# String-query to acquire injections
# All four share the same FROM/WHERE/ORDER BY tail and differ only in the selected columns.
injections_detA_string = select_dimensions_detA + acquire_injections
injections_detB_string = select_dimensions_detB + acquire_injections
injections_coinc_string = select_dimensions_coinc + acquire_injections
injections_sim_string = select_dimensions_sim + acquire_injections

# String-query to acquire all data
# Each of these additionally selects experiment_summary.datatype.
all_data_detA_string = select_dimensions_detA + acquire_all_data
all_data_detB_string = select_dimensions_detB + acquire_all_data
all_data_coinc_string = select_dimensions_coinc + acquire_all_data

#######################################################################

## Define Dictionaries ######
# Each dictionary maps a datatype name ("all_injections", "exact_injections",
# "timeslides", "zerolags") to a list: database rows for the first four,
# Trigger objects for all_trigs_dict.

all_detA_trigs_dict = {}
all_detB_trigs_dict = {}
all_coinc_trigs_dict = {}
all_sim_trigs_dict = {}
all_trigs_dict = {}

############ MAIN CODE STARTS HERE ###############

args = parser.parse_args()

# Exactly two detector names are needed: they are bound to the snglA.ifo / snglB.ifo
# placeholders of every query. Whitespace around each name is dropped so that
# "H1, L1" is treated like "H1,L1".
if not args.instruments:
	raise SystemExit("ERROR: Please supply a pair of detectors with --instruments, e.g. H1,L1")
ifos = sorted(ifo.strip() for ifo in args.instruments.split(','))
if len(ifos) != 2 or not all(ifos):
	raise SystemExit("ERROR: --instruments must name exactly two detectors, e.g. H1,L1")

allfeatures = (args.sngl_features_list + args.sngl_derived_features_list + args.coinc_features_list + args.difference_features_list)
if not allfeatures:
	print("ERROR: Please supply features!")
	# Exit with a failure status so that calling workflows notice the error.
	raise SystemExit(1)

print("### Getting information from database ...")


# Scratch directory handed to dbtables as tmp_path for the working copy of the database
local_disk = args.tmp_space
# Acquire name of database
working_filename = dbtables.get_connection_filename(args.database, tmp_path = local_disk, verbose = args.verbose)
# Set up connection with database
con = sqlite3.connect(working_filename)
# Enable dictionary cursor: rows can be indexed by column name
con.row_factory = sqlite3.Row
# XML view of the database; used below only to look for a sim_inspiral table
xmldoc = dbtables.get_xml(con)
# Create cursor
cur = con.cursor()

# In S6, the timeslides, zerolag, and injections are all stored in the same sqlite database.
# Everything downstream (feature construction and the cross-validation folds) needs the
# injection triggers, so the database supplied must include a sim_inspiral table.

try:
	sim_inspiral_table = table.get_table(xmldoc, lsctables.SimInspiralTable.tableName)
	is_injections = True
except ValueError:
	is_injections = False

if not is_injections:
	# Previously this case skipped all queries and failed later with a NameError.
	cur.close()
	con.close()
	dbtables.discard_connection_filename(args.database, working_filename, verbose = args.verbose)
	raise SystemExit("ERROR: database %s has no sim_inspiral table; injections are required" % args.database)


def _fetch_rows(query, parameters):
	"""Run query with the given positional parameters and return all rows."""
	cur.execute(query, parameters)
	return cur.fetchall()


ifo_pair = (ifos[0], ifos[1])

# Row dictionaries paired with the injection query that fills them. The four queries
# share the same FROM/WHERE/ORDER BY clause; their results are zipped together
# row by row when the Trigger objects are built.
injection_queries = (
	(all_detA_trigs_dict, injections_detA_string),
	(all_detB_trigs_dict, injections_detB_string),
	(all_coinc_trigs_dict, injections_coinc_string),
	(all_sim_trigs_dict, injections_sim_string),
)

# Get nearby ("all_injections") and exact injections from database:
# store data for each detector in separate lists of rows
for injection_key, injection_tag in (("all_injections", args.nearby_tag), ("exact_injections", args.exact_tag)):
	for trigs_dict, query in injection_queries:
		trigs_dict[injection_key] = _fetch_rows(query, ifo_pair + (injection_tag,))

# Get all data from database: store data for each detector in separate lists of rows
all_data_detA = _fetch_rows(all_data_detA_string, ifo_pair)
all_data_detB = _fetch_rows(all_data_detB_string, ifo_pair)
all_data_coinc = _fetch_rows(all_data_coinc_string, ifo_pair)

# All rows are in memory now: release the database before discarding the working copy
cur.close()
con.close()
dbtables.discard_connection_filename(args.database, working_filename, verbose = args.verbose)

# Separate data into timeslide and zerolag
zerolag_datatypes = ["playground", "exclude_play"]
for trigs_dict, rows in (
		(all_detA_trigs_dict, all_data_detA),
		(all_detB_trigs_dict, all_data_detB),
		(all_coinc_trigs_dict, all_data_coinc)):
	trigs_dict["timeslides"] = [row for row in rows if row["datatype"] == "slide"]
	trigs_dict["zerolags"] = [row for row in rows if row["datatype"] in zerolag_datatypes]

# Fill sim info with dummy values (one empty dict per coincidence) so that the
# four per-datatype lists can be zipped together
all_sim_trigs_dict["timeslides"] = [{} for _ in all_coinc_trigs_dict["timeslides"]]
all_sim_trigs_dict["zerolags"] = [{} for _ in all_coinc_trigs_dict["zerolags"]]

print("### Done!")



datatypes = ["all_injections", "exact_injections", "timeslides", "zerolags"]
# Signal (1) / noise (0) flag for each datatype above
sn_bool_list = [1,1,0,0]


######### Set attributes of trigger class objects. Attributes acquired/constructed from database columns. ######
print("### Constructing features, saving as trigger objects...")

# determine list of sngl_inspiral columns in advance
# restrict this to the ones that may possibly be used for features
sngl_columns = derived_feature_columns + args.sngl_features_list

# The reduced chisq values are always calculated (other features use them), so only the
# remaining derived features need an explicit call; difference features follow them.
requested_calculations = ([feature for feature in args.sngl_derived_features_list if "reduced" not in feature]
				+ args.difference_features_list)

for dtype, sn_bool in zip(datatypes, sn_bool_list):

	# Create temporary trigger list
	trig_tmp_list = []

	for trig_detA, trig_detB, trig_coinc, trig_sim in zip(
			all_detA_trigs_dict[dtype], all_detB_trigs_dict[dtype], all_coinc_trigs_dict[dtype], all_sim_trigs_dict[dtype]):

		# Exact coinc inspiral time: required to create instance of trigger class
		coinc_inspiral_time = lsctables.LIGOTimeGPS(trig_coinc["end_time"], trig_coinc["end_time_ns"])

		# Create instance of trigger class object with datatype, signal/noise boolean, coinc inspiral time as attributes
		trig = Trigger(args,dtype,sn_bool,coinc_inspiral_time)

		######### All sngl features for detectors A and B ######
		for feature in sngl_columns:
			setattr(trig, "a_" + feature, trig_detA[feature])
			setattr(trig, "b_" + feature, trig_detB[feature])

		##### User specified coinc features. ###########
		for feature in args.coinc_features_list:
			setattr(trig, "coinc_" + feature, trig_coinc[feature])

		#### Coinc event id ####
		trig.coinc_event_id = trig_coinc["coinc_event_id"]

		#### Calculate reduced chisq values for use in other features ####
		trig.calc_reduced_chisq()
		trig.calc_reduced_bank_chisq()
		trig.calc_reduced_cont_chisq()

		#### User specified single derived features and difference features ####
		# Dispatch by method name instead of eval(); the feature names are
		# restricted to the parser's choices, which match the calc_* methods.
		for feature in requested_calculations:
			getattr(trig, "calc_" + feature)()

		### Weights for signal triggers ######
		if sn_bool == 1:
			trig.distance = trig_sim["distance"]
			trig.d_distr = trig_sim["value"]
			trig.sim_mchirp = trig_sim["mchirp"]
		else:
			trig.distance = -1  # dummy value for non-injection coincs
			trig.d_distr = None
		trig.set_weight(args.weighting_method, args.min_weight, args.distance_power, args.reference_dist)

		# Append triggers to temporary list
		trig_tmp_list.append(trig)

	all_trigs_dict[dtype] = trig_tmp_list

print("### Done!")

########### Plot Histogram for Weights ###############

def _plot_log_histogram(signal_values, xlabel, plot_file, noise_values=None, num_bins=30):
	"""Save a log-log histogram of signal_values (red) to args.plot_path/plot_file.

	The signal bins are logarithmically spaced between the smallest and largest
	signal value, so signal_values must be non-empty and positive. If
	noise_values is given and non-empty it is drawn first (blue) with bins
	logarithmically spaced between 0.1 and 10.
	"""
	plt.figure()
	plt.rc('text', usetex=True)
	plt.ylabel("Number of Found Injections")
	plt.title(args.weighting_method.replace("_"," "))
	plt.xlabel(xlabel)
	if noise_values:
		bins_noise = np.logspace(-1,1,num_bins)
		plt.hist(noise_values, bins = bins_noise, facecolor='b', label = "Noise", alpha = 0.5)
	bins_signal = np.logspace(np.log10(min(signal_values)), np.log10(max(signal_values)), num_bins)
	plt.hist(signal_values, bins = bins_signal, facecolor='r', label = "Signal", alpha = 0.5)
	plt.xscale('log')
	plt.legend()
	plt.yscale('log', nonposy='clip')
	plt.savefig("%s/%s"%(args.plot_path, plot_file))
	plt.close()


if args.plot_weights_histogram and args.weighting_method is not None:

	print("Plotting all found injection weights histogram ...")
	found_injections = all_trigs_dict["all_injections"]

	if not found_injections:
		# min()/max() of an empty list would raise; there is nothing to plot.
		print("No found injections: skipping histograms")
	else:
		########## Weights Histogram ############
		signal_weights_list = [trig.weight for trig in found_injections]
		noise_weights_list = [trig.weight for trig in all_trigs_dict["timeslides"]]
		_plot_log_histogram(signal_weights_list, r"$\log_{\mathrm{10}}$ w",
				"w_" + args.plot_name, noise_values=noise_weights_list)

		######### Chirp Distance Histogram ########
		# Computed here rather than read from trig.chirp_dist, which is not set
		# on the triggers when weighting by physical distance.
		chirp_dist_list = [chirp_dist_func(trig.distance, trig.sim_mchirp) for trig in found_injections]
		_plot_log_histogram(chirp_dist_list, r"$\log_{\mathrm{10}}$ $d_{\mathrm{chirp}}$",
				"chirp-dist_" + args.plot_name)

		######## Physical Distance Histogram #######
		dist_list = [trig.distance for trig in found_injections]
		_plot_log_histogram(dist_list, r"$\log_{\mathrm{10}}$ $d$",
				"dist_" + args.plot_name)

	print("Done!")


#####################################################

#### Merge features lists with appropriate tags for features #####
# Column order: a_/b_ pairs of the single-detector features, the coinc_ features,
# a_/b_ pairs of the derived single-detector features, then the difference features.
sngl_feature_names = [tag + feature for feature in args.sngl_features_list for tag in ("a_", "b_")]
coinc_feature_names = ["coinc_" + feature for feature in args.coinc_features_list]
derived_feature_names = [tag + feature for feature in args.sngl_derived_features_list for tag in ("a_", "b_")]

features_list = sngl_feature_names + coinc_feature_names + derived_feature_names + args.difference_features_list

# The saved rows carry two extra trailing columns which are not features.
output_columns = features_list + ["weight", "sn_bool"]

# Format to save features in .pat files:
# every feature and the weight as a float, sn_bool as an integer.
formats = ["%1.6e"]*len(features_list) + ["%1.6e", "%d"]


############################ K-fold Cross Validation ########################################
## OUTLINE: K-fold cross validation involves splitting signal and noise data into K partitions.
# One partition of each data type (signal or noise) is assigned for classifier testing/evaluation. The remaining
# partitions are assigned for classifier training.

K = args.number
print("### Starting %d-fold Cross Validation ..."%K)

# The fold boundaries below are read from the first trigger of every exact-injection
# partition and from the first/last nearby injection, so none of them may be empty.
if K < 1:
	raise SystemExit("ERROR: --number must be at least 1")
if not all_trigs_dict["all_injections"]:
	raise SystemExit("ERROR: no nearby injections found; cannot set up cross validation")
if len(all_trigs_dict["exact_injections"]) < K:
	raise SystemExit("ERROR: found %d exact injections, fewer than the number of folds requested"
			% len(all_trigs_dict["exact_injections"]))

# Make K partitions/folds for signal triggers --- testing and training
exact_injections_partitions = np.array_split(all_trigs_dict["exact_injections"], K)

# Determine GPS lower-boundaries for all_injection partitions:
# the first fold starts at the first nearby injection, every later fold at the
# first exact injection of the corresponding partition
gps_lower_bounds = ([all_trigs_dict["all_injections"][0].coinc_inspiral_time]
				+ [partition[0].coinc_inspiral_time for partition in exact_injections_partitions[1:]])

# Determine GPS upper-boundaries for all_injection partitions
# Add 1 to the highest upper bound as we will be using a 'strictly less than gps_upper' condition
gps_upper_bounds = gps_lower_bounds[1:] + [all_trigs_dict["all_injections"][-1].coinc_inspiral_time + 1]

# Do we really need to shuffle?
# Fixed seed: the shuffled order is the same on every run
random.seed(2)
random.shuffle(all_trigs_dict["timeslides"])

# Make K partitions/folds for noise triggers --- testing and training
timeslides_partitions = np.array_split(all_trigs_dict["timeslides"], K)


################ Save testing and training signal data in folds: to be used later for cross-validation ###########

def _pat_filename(set_index, label):
	"""Return '<ifos>_<output tag>_set<set_index>_<label>.pat'."""
	return ''.join(ifos) + '_' + args.output_tag + '_set' + str(set_index) + '_' + label + '.pat'


def _feature_rows(triggers):
	"""Return one row of output_columns attribute values per trigger."""
	return tuple([[getattr(trig,c) for c in output_columns] for trig in triggers])


def _info_rows(triggers):
	"""Return one [coinc_event_id, database name] row per trigger."""
	return tuple([[trig.coinc_event_id, args.database] for trig in triggers])


def _save_features(filename, rows):
	"""Write feature rows to filename, then add the number of features and the headers."""
	np.savetxt(filename, rows, fmt=formats, delimiter='\t')
	put_headers(filename, features_list)


# The zerolag triggers are identical in every round: extract them once
# (every round still gets its own copy of the zerolag files).
zerolag_trigger_features = _feature_rows(all_trigs_dict["zerolags"])
zerolag_trigger_info = _info_rows(all_trigs_dict["zerolags"])

for i, (gps_lower, gps_upper, ts_partition) in enumerate(zip(gps_lower_bounds, gps_upper_bounds, timeslides_partitions)):

	############# Acquire and save test trigger features ###########

	print("### Acquiring test trigger features for round #%d in round-robin ..."%i)
	# Signal triggers for testing (taken from all_injections, which include nearby + exact injections)
	all_injections_testing = [trig for trig in all_trigs_dict["all_injections"]
						if trig.coinc_inspiral_time >= gps_lower and trig.coinc_inspiral_time < gps_upper]

	# Noise triggers for testing (taken from timeslides, which does NOT include zerolag triggers)
	timeslides_testing = ts_partition.tolist()

	# Merge signal and noise triggers for testing
	triggers_testing = all_injections_testing + timeslides_testing

	# Save testing trigger features in .pat files tagged with a partition/set number tag
	_save_features(_pat_filename(i, 'evaluation'), _feature_rows(triggers_testing))

	# Save the coinc event id and database name of every testing trigger, row-aligned with the features
	np.savetxt(_pat_filename(i, 'evaluation_info'), _info_rows(triggers_testing), fmt="%s", delimiter='\t')

	###############################################################

	############# Acquire and save training trigger features ###########

	print("### Acquiring training trigger features for round #%d in round-robin ..."%i)
	# Signal triggers for training: exact injections of every partition except partition i
	exact_injections_training = list(itertools.chain.from_iterable(
								exact_injections_partitions[:i]+exact_injections_partitions[i+1:]))

	# Noise triggers for training: timeslides (no zerolags) of every partition except partition i
	timeslides_training = list(itertools.chain.from_iterable(timeslides_partitions[:i]+timeslides_partitions[i+1:]))

	# Merge signal and noise triggers for training
	triggers_training = exact_injections_training + timeslides_training

	# Save training trigger features in .pat files tagged with a partition/set number tag
	_save_features(_pat_filename(i, "training"), _feature_rows(triggers_training))

	################################################################

	#### Save zerolag trigger features in one file. Remark: Hodge has 10 copies of the same zerolag file. Is this necessary? ##########
	print("### Saving zerolag trigger features for round #%d in round-robin..."%i)

	# Save the trigger features in .pat files tagged with a partition/set number tag
	_save_features(_pat_filename(i, "zerolag"), zerolag_trigger_features)

	# Save the coinc event id and database name of every zerolag trigger
	np.savetxt(_pat_filename(i, 'zerolag_info'), zerolag_trigger_info, fmt="%s", delimiter='\t')


print("### Done!")

# Report the wall-clock time elapsed since time1 was taken at start-up
time2 = time()
elapsed_time = time2-time1
print("Run-time = %f seconds"%elapsed_time)

		


back to top