# Source code for svd_bank

# Copyright (C) 2010  Kipp Cannon, Chad Hanna, Leo Singer
# Copyright (C) 2009  Kipp Cannon, Chad Hanna
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

## @file
# The module to implement SVD decomposition of CBC waveforms
#
# ### Review Status
#
# | Names                                          | Hash                                        | Date       | Diff to Head of Master      |
# | -------------------------------------------    | ------------------------------------------- | ---------- | --------------------------- |
# | Florent, Sathya, Duncan Me, Jolien, Kipp, Chad | 7536db9d496be9a014559f4e273e1e856047bf71    | 2014-04-30 | <a href="@gstlal_inspiral_cgit_diff/python/svd_bank.py?id=HEAD&id2=7536db9d496be9a014559f4e273e1e856047bf71">svd_bank.py</a> |
#
# #### Actions
# - Consider a study of how to supply the svd / time slice boundaries
#

## @package svd_bank


#
# =============================================================================
#
#				   Preamble
#
# =============================================================================
#


import copy
import numpy
import os
import sys
import warnings
import scipy

import lal

from ligo.lw import ligolw
from ligo.lw import lsctables
from ligo.lw import array as ligolw_array
from ligo.lw import param as ligolw_param
from ligo.lw import utils as ligolw_utils
from ligo.lw.utils import process as ligolw_process

Attributes = ligolw.sax.xmlreader.AttributesImpl

from gstlal import cbc_template_fir
from gstlal import misc as gstlalmisc
from gstlal import templates
from gstlal.psd import condition_psd, HorizonDistance

from gstlal import chirptime
from gstlal import spawaveform

class DefaultContentHandler(ligolw.LIGOLWContentHandler):
	"""
	LIGO_LW content handler for SVD bank documents.  The Array, Param
	and lsctables element classes are registered on it immediately
	below this definition.
	"""


for _element_module in (ligolw_array, ligolw_param, lsctables):
	_element_module.use_in(DefaultContentHandler)
del _element_module


#
# =============================================================================
#
#                                  Utilities
#
# =============================================================================
#


#
# Read approximant
#
def read_approximant(xmldoc, programs = ("gstlal_bank_splitter",)):
	"""
	Return the waveform approximant recorded in the document's
	process_params table.

	The value is taken from the "--approximant" process_params entries
	that belong to processes whose program name is in programs, and is
	passed to templates.gstlal_valid_approximant() before being
	returned.

	@param xmldoc The LIGO_LW document to search.
	@param programs Iterable of program names whose process_params
	entries are consulted.

	Raises ValueError if no matching process exists, if no
	"--approximant" entry is found, or if more than one distinct
	approximant is recorded.
	"""
	# materialize so a one-shot iterable can be used both for the lookup
	# and for the error messages
	programs = tuple(programs)
	process_table = lsctables.ProcessTable.get_table(xmldoc)
	process_ids = set()
	for program in programs:
		process_ids |= process_table.get_ids_by_program(program)
	if not process_ids:
		raise ValueError("document must contain process entries from %s" % ", ".join(programs))
	approximant = set(row.pyvalue for row in lsctables.ProcessParamsTable.get_table(xmldoc) if (row.process_id in process_ids) and (row.param == "--approximant"))
	if not approximant:
		raise ValueError("document must contain an 'approximant' process_params entry from %s" % ", ".join("'%s'" % program for program in programs))
	if len(approximant) > 1:
		raise ValueError("document must contain only one approximant")
	approximant = approximant.pop()
	templates.gstlal_valid_approximant(approximant)
	return approximant
# # check final frequency is populated and return the max final frequency #
def check_ffinal_and_find_max_ffinal(xmldoc):
	"""
	Return the largest f_final value in the document's sngl_inspiral
	table.

	Raises ValueError if the table has no rows, or if any row's f_final
	is falsy (e.g. zero), which is taken to mean the column was never
	populated.
	"""
	f_final = list(lsctables.SnglInspiralTable.get_table(xmldoc).getColumnByName("f_final"))
	if not f_final:
		# without this, max() below would fail with an opaque message
		raise ValueError("sngl_inspiral table is empty, cannot determine the maximum f_final")
	if not all(f_final):
		raise ValueError("f_final column not populated")
	return max(f_final)
# # sum-of-squares false alarm probability #
def sum_of_squares_threshold_from_fap(fap, coefficients):
	"""
	Return the sum-of-squares gate threshold for the false-alarm
	probability fap.

	Delegates to gstlalmisc.max_stat_thresh(), which takes its arguments
	in the opposite order (coefficients first).

	@param fap The target false-alarm probability.
	@param coefficients Array of sum-of-squares coefficients.
	"""
	threshold = gstlalmisc.max_stat_thresh(coefficients, fap)
	return threshold
	# disabled alternative implementation, kept for reference:
	#return gstlalmisc.cdf_weighted_chisq_Pinv(coefficients, numpy.zeros(coefficients.shape, dtype = "double"), numpy.ones(coefficients.shape, dtype = "int"), 0.0, 1.0 - fap, -1, fap / 16.0)
def group(inlist, parts):
	"""!
	Split inlist into consecutive chunks whose sizes follow parts.

	Each size in parts is used up to len(inlist) // sum(parts) + 1 times,
	in order; once the list is exhausted no further chunks are yielded.
	The input is deep-copied first, so the yielded chunks never alias
	the elements of inlist.  e.g.

	>>> A = list(range(12))
	>>> B = [2,3]
	>>> for g in group(A,B):
	...     print(g)
	...
	[0, 1]
	[2, 3]
	[4, 5]
	[6, 7, 8]
	[9, 10, 11]
	"""
	repeats = len(inlist) // sum(parts) + 1
	remaining = copy.deepcopy(inlist)
	for size in parts:
		for _ in range(repeats):
			if not remaining:
				break
			chunk = remaining[:size]
			del remaining[:size]
			yield chunk
# # ============================================================================= # # Pipeline Metadata # # ============================================================================= #
class BankFragment(object):
	"""
	One time slice of an SVD-decomposed template bank.

	Stores the slice's sample rate and start/end times.  After
	set_template_bank() has run it also carries the outputs of
	cbc_template_fir.decompose_templates() (orthogonal basis, singular
	values, mixing matrix, chifacs) and the derived sum-of-squares
	weights.
	"""
	def __init__(self, rate, start, end):
		self.rate, self.start, self.end = rate, start, end

	def set_template_bank(self, template_bank, tolerance, snr_thresh, identity_transform = False, verbose = False):
		"""
		Decompose template_bank and store the results on this fragment.

		@param template_bank The templates for this slice; its .shape
		is reported as (templates, samples) in verbose output.
		@param tolerance Passed to cbc_template_fir.decompose_templates().
		@param snr_thresh Passed to gstlalmisc.ss_coeffs() when computing
		the sum-of-squares weights.
		@param identity_transform Passed as identity= to
		decompose_templates().
		@param verbose Print progress to stderr.
		"""
		if verbose:
			print("\t%d templates of %d samples" % template_bank.shape, file=sys.stderr)

		decomposition = cbc_template_fir.decompose_templates(template_bank, tolerance, identity = identity_transform)
		self.orthogonal_template_bank, self.singular_values, self.mix_matrix, self.chifacs = decomposition

		# sum-of-squares weights are only defined when singular values exist
		if self.singular_values is None:
			self.sum_of_squares_weights = None
		else:
			squared_weights = self.chifacs.mean() * gstlalmisc.ss_coeffs(self.singular_values,snr_thresh)
			self.sum_of_squares_weights = numpy.sqrt(squared_weights)

		if verbose:
			print("\tidentified %d components" % self.orthogonal_template_bank.shape[0], file=sys.stderr)
			print("\tsum-of-squares expectation value is %g" % self.chifacs.mean(), file=sys.stderr)
class Bank(object):
	"""
	An SVD-decomposed template bank split into time-sliced BankFragment
	objects.

	Instances are built from a sub-bank document by __init__() or
	reconstructed from disk by read_banks(), which creates them with
	Bank.__new__() and so bypasses __init__().
	"""
	def __init__(self, bank_xmldoc, psd, time_slices, gate_fap, snr_threshold, tolerance, clipleft = None, clipright = None, flow = 40.0, autocorrelation_length = None, logname = None, identity_transform = False, bank_type = "signal_model", verbose = False, bank_id = None, fhigh = None):
		"""
		@param bank_xmldoc LIGO_LW document holding the sub-bank's
		sngl_inspiral table and the process params read by
		read_approximant().
		@param psd The PSD passed to cbc_template_fir.generate_templates().
		@param time_slices Time-slice description with an 'end' field; its
		maximum becomes filter_length.
		@param gate_fap False-alarm probability for the sum-of-squares
		gate threshold.
		@param snr_threshold The SNR threshold for the search.
		@param tolerance SVD tolerance passed to each fragment.
		@param clipleft Number of templates to drop from the start of the
		bank (None: drop none).
		@param clipright Number of templates to drop from the end of the
		bank (None: drop none).
		@param flow Lower frequency passed to generate_templates().
		@param autocorrelation_length Passed to generate_templates().
		@param logname Optional name; must be non-empty if given.
		@param identity_transform Skip the SVD and keep raw waveforms.
		@param bank_type "signal_model" or "noise_model"; noise-model
		templates are time reversed.
		@param verbose Print progress to stderr.
		@param bank_id Bank identifier string; when given, its
		"_"-separated numeric prefix is stored in each kept template's
		Gamma1 column.
		@param fhigh Upper frequency passed to generate_templates().

		Raises ValueError if logname is set but empty.
		"""
		# FIXME: remove template_bank_filename when no longer needed
		# by trigger generator element
		self.template_bank_filename = None
		self.filter_length = time_slices['end'].max()
		self.snr_threshold = snr_threshold
		self.bank_type = bank_type
		if logname is not None and not logname:
			raise ValueError("logname cannot be empty if it is set")
		self.logname = logname
		self.bank_id = bank_id

		# Generate downsampled templates
		template_bank, self.autocorrelation_bank, self.autocorrelation_mask, self.sigmasq, bank_workspace = cbc_template_fir.generate_templates(
			lsctables.SnglInspiralTable.get_table(bank_xmldoc),
			read_approximant(bank_xmldoc),
			psd,
			flow,
			time_slices,
			autocorrelation_length = autocorrelation_length,
			fhigh = fhigh,
			time_reverse = bank_type == "noise_model",
			verbose = verbose)

		# Include signal inspiral table
		sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(bank_xmldoc)
		self.sngl_inspiral_table = sngl_inspiral_table.copy()
		self.sngl_inspiral_table.extend(sngl_inspiral_table)

		# Include the processed psd
		self.processed_psd = bank_workspace.psd

		# Include some parameters passed from the bank workspace
		self.newdeltaF = 1. / bank_workspace.working_duration
		self.working_f_low = bank_workspace.working_f_low
		self.f_low = bank_workspace.f_low
		self.sample_rate_max = bank_workspace.sample_rate_max

		# Assign template banks to fragments
		self.bank_fragments = [BankFragment(rate, begin, end) for rate, begin, end in bank_workspace.time_slices]

		# Decompose each fragment and accumulate the bank correlation
		# matrix as the sum over fragments of cmix^H cmix
		self.bank_correlation_matrix = None
		for i, bank_fragment in enumerate(self.bank_fragments):
			if verbose:
				print("constructing template decomposition %d of %d: %g s ... %g s" % (i + 1, len(self.bank_fragments), -bank_fragment.end, -bank_fragment.start), file=sys.stderr)
			bank_fragment.set_template_bank(template_bank[i], tolerance, self.snr_threshold, identity_transform = identity_transform, verbose = verbose)
			# even columns of the mix matrix are combined as real parts and
			# odd columns as imaginary parts
			cmix = bank_fragment.mix_matrix[:,::2] + 1.j * bank_fragment.mix_matrix[:,1::2]
			contribution = numpy.dot(numpy.conj(cmix.T), cmix)
			if self.bank_correlation_matrix is None:
				self.bank_correlation_matrix = contribution
			else:
				self.bank_correlation_matrix += contribution

		# The gate threshold is only defined when sum-of-squares weights
		# exist; the last fragment is used as the indicator (presumably
		# all fragments agree since they share identity_transform --
		# TODO confirm).
		if self.bank_fragments and self.bank_fragments[-1].sum_of_squares_weights is not None:
			coefficients = numpy.array([weight**2 for frag in self.bank_fragments for weight in frag.sum_of_squares_weights], dtype = "double")
			self.gate_threshold = sum_of_squares_threshold_from_fap(gate_fap, coefficients)
		else:
			self.gate_threshold = 0.
		if verbose:
			print("sum-of-squares threshold for false-alarm probability of %.16g: %.16g" % (gate_fap, self.gate_threshold), file=sys.stderr)

		# Sanity checks before clipping: convert clipright from a count
		# of templates into an end index.  Each template occupies two
		# entries in the fragments' mix_matrix columns and chifacs (see
		# the cmix pairing above), hence the doubled indices.
		clipright = len(self.sngl_inspiral_table) - clipright if clipright is not None else None
		doubled_clipright = clipright * 2 if clipright is not None else None
		doubled_clipleft = clipleft * 2 if clipleft is not None else None

		# Apply clipping options
		# FIXME need a proper id column; Gamma1 carries the numeric prefix
		# of bank_id.  bank_id is optional, so only set it when provided.
		bank_number = int(self.bank_id.split("_")[0]) if self.bank_id is not None else None
		new_sngl_table = self.sngl_inspiral_table.copy()
		for row in self.sngl_inspiral_table[clipleft:clipright]:
			if bank_number is not None:
				row.Gamma1 = bank_number
			new_sngl_table.append(row)
		self.sngl_inspiral_table = new_sngl_table

		self.autocorrelation_bank = self.autocorrelation_bank[clipleft:clipright,:]
		self.autocorrelation_mask = self.autocorrelation_mask[clipleft:clipright,:]
		self.sigmasq = self.sigmasq[clipleft:clipright]
		self.bank_correlation_matrix = self.bank_correlation_matrix[clipleft:clipright,clipleft:clipright]

		for frag in self.bank_fragments:
			if frag.mix_matrix is not None:
				frag.mix_matrix = frag.mix_matrix[:,doubled_clipleft:doubled_clipright]
			frag.chifacs = frag.chifacs[doubled_clipleft:doubled_clipright]

	def get_rates(self):
		"""Return the set of sample rates used by this bank's fragments."""
		return set(bank_fragment.rate for bank_fragment in self.bank_fragments)

	# FIXME: remove set_template_bank_filename when no longer needed
	# by trigger generator element
	def set_template_bank_filename(self,name):
		"""Set the template_bank_filename attribute to name."""
		self.template_bank_filename = name
def cal_higher_f_low(template_bank_url, bank_xmldoc, flow, max_duration):
	"""
	Return the lower frequency cutoff needed so that no template is
	longer than max_duration.

	For every template in the document's sngl_inspiral table, the
	frequency at which chirptime.imr_time() equals max_duration is found
	with scipy.optimize.fsolve() (starting from flow).  The result is the
	larger of flow and the largest of these per-template frequencies.
	The document is not modified.

	@param template_bank_url Not used; kept for interface compatibility.
	@param bank_xmldoc LIGO_LW document holding the sub-bank's
	sngl_inspiral table and gstlal_bank_splitter process params.
	@param flow The requested lower frequency cutoff.
	@param max_duration The maximum allowed template duration.

	Raises ValueError if the resulting lower frequency exceeds the
	largest f_final in the bank, i.e. max_duration is too short.
	"""
	# explicit submodule import: "import scipy" alone does not guarantee
	# scipy.optimize is loaded on every SciPy version
	import scipy.optimize

	def time_freq_bound(flow, max_duration, m1, m2, j1, j2, f_max):
		"""
		Residual whose root in flow is where the template duration
		equals max_duration.
		"""
		return chirptime.imr_time(f = flow, m1 = m1, m2 = m2, j1 = j1, j2 = j2, f_max = f_max) - max_duration

	# Get sngl inspiral table
	bank_sngl_table = lsctables.SnglInspiralTable.get_table(bank_xmldoc)
	f_high = check_ffinal_and_find_max_ffinal(bank_xmldoc)	# maximum frequency cut off
	approximant = read_approximant(bank_xmldoc)
	is_imr = approximant in templates.gstlal_IMR_approximants

	time_constrained_f_low = []
	for row in bank_sngl_table:
		m1_SI = lal.MSUN_SI * row.mass1
		m2_SI = lal.MSUN_SI * row.mass2
		spin1 = numpy.dot(row.spin1, row.spin1)**.5
		spin2 = numpy.dot(row.spin2, row.spin2)**.5
		# upper frequency used for the duration estimate: ringdown-based
		# for IMR approximants, BKL ISCO otherwise, capped by f_final
		if is_imr:
			f_cut = 2 * chirptime.ringf(m1_SI + m2_SI, chirptime.overestimate_j_from_chi(max(spin1, spin2)))
		else:
			f_cut = spawaveform.ffinal(row.mass1, row.mass2, 'bkl_isco')
		f_max = min(row.f_final, f_cut)
		# fsolve returns a length-1 array; take the scalar so the float()
		# conversion below does not operate on an ndim > 0 array
		root = scipy.optimize.fsolve(time_freq_bound, x0 = flow, args = (max_duration, m1_SI, m2_SI, spin1, spin2, f_max))
		time_constrained_f_low.append(root[0])

	f_low = float(max(flow, max(time_constrained_f_low)))
	if f_high is not None and f_high < f_low:
		raise ValueError("Lower frequency must be lower than higher frequency cut off! Input max_duration is too short.")
	return f_low
def build_bank(template_bank_url, psd, flow, max_duration, ortho_gate_fap, snr_threshold, svd_tolerance, clipleft = None, clipright = None, padding = 1.5, identity_transform = False, bank_type = "signal_model", verbose = False, autocorrelation_length = 201, samples_min = 1024, samples_max_256 = 1024, samples_max_64 = 2048, samples_max = 4096, bank_id = None, contenthandler = None, sample_rate = None, instrument_override = None):
	"""!
	Return an instance of a Bank class.

	@param template_bank_url The template bank filename or url containing a subbank of templates to decompose in a single inspiral table.
	@param psd A dictionary of PSDs keyed by instrument; the entry for the instrument (ifo) of the first template is used.
	@param flow The requested lower frequency cutoff; it may be raised by cal_higher_f_low() so that no template exceeds max_duration.
	@param max_duration The maximum time duration of waveform to set the lower frequency cutoff.
	@param ortho_gate_fap The FAP threshold for the sum of squares threshold, see http://arxiv.org/abs/1101.0584
	@param snr_threshold The SNR threshold for the search
	@param svd_tolerance The target SNR loss of the SVD, see http://arxiv.org/abs/1005.0012
	@param clipleft The number of N poorly reconstructed templates from the left edge of each sub-bank to be removed
	@param clipright The number of N poorly reconstructed templates from the right edge of each sub-bank to be removed
	@param padding The padding from Nyquist for any template time slice, e.g., if a time slice has a Nyquist of 256 Hz and the padding is set to 2, only allow the template frequency to extend to 128 Hz.
	@param identity_transform Don't do the SVD, just do time slices and keep the raw waveforms
	@param bank_type Define the type of the template bank, is it for producing signal candidates ("signal_model") or for producing noise candidates ("noise_model")
	@param verbose Be verbose
	@param autocorrelation_length The number of autocorrelation samples to use in the chisquared test. Must be odd
	@param samples_min The minimum number of samples to use in any time slice
	@param samples_max_256 The maximum number of samples to have in any time slice greater than or equal to 256 Hz
	@param samples_max_64 The maximum number of samples to have in any time slice greater than or equal to 64 Hz
	@param samples_max The maximum number of samples in any time slice below 64 Hz
	@param bank_id The id of the bank in question
	@param contenthandler The ligolw content handler for file I/O
	@param sample_rate Passed to templates.time_slices(); when not None, the bank's maximum f_final is also passed to Bank as fhigh.
	@param instrument_override If not None, overwrite the ifo of every template with this instrument.
	"""
	# Open template bank file
	bank_xmldoc = ligolw_utils.load_url(template_bank_url, contenthandler = contenthandler, verbose = verbose)

	# Get sngl inspiral table
	bank_sngl_table = lsctables.SnglInspiralTable.get_table(bank_xmldoc)

	# override instrument if needed (this is useful if a generic
	# instrument-independent bank file is provided)
	if instrument_override is not None:
		for row in bank_sngl_table:
			row.ifo = instrument_override

	# use "search" to indicate that the filter (template)
	# is for noise model estimation and identify which
	# background collector to use
	# FIXME: create new column for this information, using
	# "search" is only because it's not being used
	for row in bank_sngl_table:
		row.search = bank_type

	# Both helpers only read the document, so compute them once;
	# cal_higher_f_low() runs a root solve per template and is the
	# costlier of the two.
	max_f_final = check_ffinal_and_find_max_ffinal(bank_xmldoc)
	f_low = cal_higher_f_low(template_bank_url, bank_xmldoc, flow, max_duration)

	# Choose how to break up templates in time
	time_freq_bounds = templates.time_slices(
		bank_sngl_table,
		fhigh = max_f_final,
		flow = f_low,
		padding = padding,
		samples_min = samples_min,
		samples_max_256 = samples_max_256,
		samples_max_64 = samples_max_64,
		samples_max = samples_max,
		sample_rate = sample_rate,
		verbose = verbose)

	fhigh = max_f_final if sample_rate is not None else None

	# Generate templates, perform SVD, get orthogonal basis
	# and store as Bank object
	bank = Bank(
		bank_xmldoc,
		psd[bank_sngl_table[0].ifo],
		time_freq_bounds,
		gate_fap = ortho_gate_fap,
		snr_threshold = snr_threshold,
		tolerance = svd_tolerance,
		clipleft = clipleft,
		clipright = clipright,
		flow = f_low,
		autocorrelation_length = autocorrelation_length,	# samples
		identity_transform = identity_transform,
		bank_type = bank_type,
		verbose = verbose,
		bank_id = bank_id,
		fhigh = fhigh
	)

	# FIXME: remove this when no longer needed
	# by trigger generator element.
	bank.set_template_bank_filename(ligolw_utils.local_path_from_url(template_bank_url))
	return bank
def write_bank(filename, banks, psd_input, process_param_dict = None, verbose = False):
	"""
	Write SVD banks to a LIGO_LW xml file.

	Each bank becomes a LIGO_LW element named "gstlal_svd_bank_Bank"
	holding its sngl_inspiral table, scalar Params, arrays and one child
	LIGO_LW element per bank fragment.  One PSD, taken from psd_input for
	the instrument of the last bank's first template, is appended.

	The banks' sngl_inspiral rows are updated in place: template_duration
	is set to the last fragment's end, Gamma2-Gamma5 are filled by
	calc_lambda_eta_sum(), and template ids of non-signal banks are made
	negative.

	@param filename Output file name.
	@param banks Iterable of Bank objects.
	@param psd_input Dictionary of PSDs keyed by instrument.
	@param process_param_dict If given, register a
	"gstlal_inspiral_svd_bank" process with these parameters.
	@param verbose Be verbose.

	Raises KeyError if the GSTLAL_FIR_WHITEN environment variable is not
	set.
	"""
	# lal.series is a submodule; "import lal" alone does not guarantee it
	# is loaded
	import lal.series

	# Create new document
	xmldoc = ligolw.Document()
	lw = xmldoc.appendChild(ligolw.LIGO_LW())

	# add Process Parameter Table into svd_bank file
	if process_param_dict:
		ligolw_process.register_to_xmldoc(xmldoc, program="gstlal_inspiral_svd_bank", paramdict=process_param_dict, comment="Process parameter tables for further calculation")

	last_bank = None
	for bank in banks:
		last_bank = bank
		# set up root for this sub bank
		root = lw.appendChild(ligolw.LIGO_LW(Attributes({u"Name": u"gstlal_svd_bank_Bank"})))

		for row in bank.sngl_inspiral_table:
			row.template_duration = bank.bank_fragments[-1].end
			# make non-signal model templates have an invalid (negative)
			# template id.  -abs() instead of plain negation keeps this
			# idempotent if the same bank is written more than once.
			if bank.bank_type != "signal_model":
				row.template_id = -abs(row.template_id)

		if verbose:
			print("computing lambda/eta parameters for templates...", file=sys.stderr)
		for row, auto_correlation in zip(bank.sngl_inspiral_table, bank.autocorrelation_bank):
			row.Gamma2, row.Gamma3, row.Gamma4, row.Gamma5 = calc_lambda_eta_sum(auto_correlation)

		# put the possibly clipped table into the file
		root.appendChild(bank.sngl_inspiral_table)

		# Add root-level scalar params
		root.appendChild(ligolw_param.Param.from_pyvalue('filter_length', bank.filter_length))
		root.appendChild(ligolw_param.Param.from_pyvalue('gate_threshold', bank.gate_threshold))
		root.appendChild(ligolw_param.Param.from_pyvalue('logname', bank.logname or ""))
		root.appendChild(ligolw_param.Param.from_pyvalue('snr_threshold', bank.snr_threshold))
		root.appendChild(ligolw_param.Param.from_pyvalue('template_bank_filename', bank.template_bank_filename))
		root.appendChild(ligolw_param.Param.from_pyvalue('bank_id', bank.bank_id))
		root.appendChild(ligolw_param.Param.from_pyvalue('new_deltaf', bank.newdeltaF))
		root.appendChild(ligolw_param.Param.from_pyvalue('working_f_low', bank.working_f_low))
		root.appendChild(ligolw_param.Param.from_pyvalue('f_low', bank.f_low))
		root.appendChild(ligolw_param.Param.from_pyvalue('sample_rate_max', int(bank.sample_rate_max)))
		root.appendChild(ligolw_param.Param.from_pyvalue('gstlal_fir_whiten', os.environ['GSTLAL_FIR_WHITEN']))
		root.appendChild(ligolw_param.Param.from_pyvalue('bank_type', bank.bank_type))

		# Add root-level arrays
		# FIXME: ligolw format now supports complex-valued data
		root.appendChild(ligolw_array.Array.build('autocorrelation_bank_real', bank.autocorrelation_bank.real))
		root.appendChild(ligolw_array.Array.build('autocorrelation_bank_imag', bank.autocorrelation_bank.imag))
		root.appendChild(ligolw_array.Array.build('autocorrelation_mask', bank.autocorrelation_mask))
		root.appendChild(ligolw_array.Array.build('sigmasq', numpy.array(bank.sigmasq)))
		root.appendChild(ligolw_array.Array.build('bank_correlation_matrix_real', bank.bank_correlation_matrix.real))
		root.appendChild(ligolw_array.Array.build('bank_correlation_matrix_imag', bank.bank_correlation_matrix.imag))

		# Write bank fragments
		for frag in bank.bank_fragments:
			# Start new bank fragment container
			el = root.appendChild(ligolw.LIGO_LW())

			# Add scalar params
			el.appendChild(ligolw_param.Param.from_pyvalue('rate', int(frag.rate)))
			el.appendChild(ligolw_param.Param.from_pyvalue('start', frag.start))
			el.appendChild(ligolw_param.Param.from_pyvalue('end', frag.end))

			# Add arrays; the optional ones are omitted when absent
			el.appendChild(ligolw_array.Array.build('chifacs', frag.chifacs))
			if frag.mix_matrix is not None:
				el.appendChild(ligolw_array.Array.build('mix_matrix', frag.mix_matrix))
			el.appendChild(ligolw_array.Array.build('orthogonal_template_bank', frag.orthogonal_template_bank))
			if frag.singular_values is not None:
				el.appendChild(ligolw_array.Array.build('singular_values', frag.singular_values))
			if frag.sum_of_squares_weights is not None:
				el.appendChild(ligolw_array.Array.build('sum_of_squares_weights', frag.sum_of_squares_weights))

	# put a copy of the processed PSD file in
	# FIXME in principle this could be different for each bank included in
	# this file, but we only put one here
	if last_bank is not None:
		instrument = last_bank.sngl_inspiral_table[0].ifo
		lal.series.make_psd_xmldoc({instrument: psd_input[instrument]}, lw)

	# Write to file
	ligolw_utils.write_filename(xmldoc, filename, verbose = verbose)
def read_banks(filename, contenthandler, verbose = False):
	"""
	Read SVD banks from a LIGO_LW xml file.

	One Bank object is reconstructed (via Bank.__new__(), bypassing
	__init__()) for each LIGO_LW element named "gstlal_svd_bank_Bank",
	as written by write_bank().  If the file contains a PSD, each bank's
	processed_psd is rebuilt from it with condition_psd(); otherwise it
	is None.  Finally every bank receives the same horizon_distance_func
	and its horizon_factors (sqrt(sigmasq) keyed by template_id) are
	normalized by the canonical template's factor.

	@param filename The file name or URL to load.
	@param contenthandler The ligolw content handler for file I/O.
	@param verbose Be verbose.

	Returns a list of Bank objects.

	Raises ValueError if the canonical horizon-distance template is not
	found in exactly one "signal_model" bank.
	"""
	# lal.series is a submodule; "import lal" alone does not guarantee it
	# is loaded
	import lal.series

	# Load document
	xmldoc = ligolw_utils.load_url(filename, contenthandler = contenthandler, verbose = verbose)

	banks = []

	# FIXME in principle this could be different for each bank included in
	# this file, but we only put one in the file for now
	# FIXME, right now there is only one instrument so we just pull out the
	# only psd there is
	try:
		raw_psd = list(lal.series.read_psd_xmldoc(xmldoc).values())[0]
	except ValueError:
		# the bank file does not contain psd ligolw element.
		raw_psd = None

	for root in (elem for elem in xmldoc.getElementsByTagName(ligolw.LIGO_LW.tagName) if elem.hasAttribute(u"Name") and elem.Name == "gstlal_svd_bank_Bank"):

		# Create new SVD bank object
		bank = Bank.__new__(Bank)

		# Read sngl inspiral table, detaching it so it survives unlink()
		bank.sngl_inspiral_table = lsctables.SnglInspiralTable.get_table(root)
		bank.sngl_inspiral_table.parentNode.removeChild(bank.sngl_inspiral_table)

		# Read root-level scalar parameters
		bank.filter_length = ligolw_param.get_pyvalue(root, 'filter_length')
		bank.gate_threshold = ligolw_param.get_pyvalue(root, 'gate_threshold')
		bank.logname = ligolw_param.get_pyvalue(root, 'logname') or None
		bank.snr_threshold = ligolw_param.get_pyvalue(root, 'snr_threshold')
		bank.template_bank_filename = ligolw_param.get_pyvalue(root, 'template_bank_filename')
		bank.bank_id = ligolw_param.get_pyvalue(root, 'bank_id')
		bank.bank_type = ligolw_param.get_pyvalue(root, 'bank_type')

		# these parameters may be absent from some files; tolerate that
		try:
			bank.newdeltaF = ligolw_param.get_pyvalue(root, 'new_deltaf')
			bank.working_f_low = ligolw_param.get_pyvalue(root, 'working_f_low')
			bank.f_low = ligolw_param.get_pyvalue(root, 'f_low')
			bank.sample_rate_max = ligolw_param.get_pyvalue(root, 'sample_rate_max')
		except ValueError:
			pass

		# Read root-level arrays
		bank.autocorrelation_bank = ligolw_array.get_array(root, 'autocorrelation_bank_real').array + 1j * ligolw_array.get_array(root, 'autocorrelation_bank_imag').array
		bank.autocorrelation_mask = ligolw_array.get_array(root, 'autocorrelation_mask').array
		bank.sigmasq = ligolw_array.get_array(root, 'sigmasq').array
		bank_correlation_real = ligolw_array.get_array(root, 'bank_correlation_matrix_real').array
		bank_correlation_imag = ligolw_array.get_array(root, 'bank_correlation_matrix_imag').array
		bank.bank_correlation_matrix = bank_correlation_real + 1j * bank_correlation_imag

		# prepare the horizon distance factors
		bank.horizon_factors = dict((row.template_id, sigmasq**.5) for row, sigmasq in zip(bank.sngl_inspiral_table, bank.sigmasq))

		if raw_psd is not None:
			# reproduce the whitening psd and attach a reference to the psd
			bank.processed_psd = condition_psd(raw_psd, bank.newdeltaF, minfs = (bank.working_f_low, bank.f_low), maxfs = (bank.sample_rate_max / 2.0 * 0.90, bank.sample_rate_max / 2.0))
		else:
			bank.processed_psd = None

		# Read bank fragments
		bank.bank_fragments = []
		for el in (node for node in root.childNodes if node.tagName == ligolw.LIGO_LW.tagName):
			frag = BankFragment(
				rate = ligolw_param.get_pyvalue(el, 'rate'),
				start = ligolw_param.get_pyvalue(el, 'start'),
				end = ligolw_param.get_pyvalue(el, 'end')
			)

			# Read arrays; the optional ones default to None
			frag.chifacs = ligolw_array.get_array(el, 'chifacs').array
			try:
				frag.mix_matrix = ligolw_array.get_array(el, 'mix_matrix').array
			except ValueError:
				frag.mix_matrix = None
			frag.orthogonal_template_bank = ligolw_array.get_array(el, 'orthogonal_template_bank').array
			try:
				frag.singular_values = ligolw_array.get_array(el, 'singular_values').array
			except ValueError:
				frag.singular_values = None
			try:
				frag.sum_of_squares_weights = ligolw_array.get_array(el, 'sum_of_squares_weights').array
			except ValueError:
				frag.sum_of_squares_weights = None
			bank.bank_fragments.append(frag)

		banks.append(bank)

	template_id, func = horizon_distance_func(banks)
	template_id = abs(template_id)	# make sure horizon_distance_func did not pick the noise model template
	horizon_norm = None
	for bank in banks:
		if template_id in bank.horizon_factors and bank.bank_type == "signal_model":
			# explicit check rather than assert, which is stripped under -O
			if horizon_norm is not None:
				raise ValueError("template %d appears in more than one signal_model bank" % template_id)
			horizon_norm = bank.horizon_factors[template_id]
	if horizon_norm is None:
		raise ValueError("template %d not found in any signal_model bank, cannot normalize horizon distance factors" % template_id)
	for bank in banks:
		bank.horizon_distance_func = func
		bank.horizon_factors = dict((tid, f / horizon_norm) for (tid, f) in bank.horizon_factors.items())
	xmldoc.unlink()
	return banks
def svdbank_templates_mapping(filenames, contenthandler, verbose = False):
	"""
	From a list of the names of files containing SVD bank objects,
	construct a dictionary mapping filename to list of sngl_inspiral
	templates in that file.  Typically this mapping is inverted through
	the use of some sort of "template identity" function to map each
	template to the filename that contains that template.

	Example:

	Assuming the (mass1, mass2) tuple is known to uniquely identify the
	templates

	>>> def template_id(row):
	...	return row.mass1, row.mass2
	...
	>>> mapping = svdbank_templates_mapping([], DefaultContentHandler)
	>>> template_to_filename = dict((template_id(template), filename) for filename, rows in mapping.items() for template in rows)
	"""
	mapping = {}
	for n, filename in enumerate(filenames, start = 1):
		if verbose:
			print("%d/%d:" % (n, len(filenames)), file=sys.stderr)
		# flatten the rows of every bank in the file into one list; a
		# comprehension avoids the quadratic cost of sum(..., [])
		mapping[filename] = [row for bank in read_banks(filename, contenthandler, verbose = verbose) for row in bank.sngl_inspiral_table]
	return mapping
def preferred_horizon_distance_template(banks):
	"""
	Return (template_id, mass1, mass2, spin1z, spin2z) of the template
	with the smallest template_id across all banks.

	Ties on template_id are broken by the remaining fields in order, as
	tuple comparison does.
	"""
	candidates = (
		(row.template_id, row.mass1, row.mass2, row.spin1z, row.spin2z)
		for bank in banks
		for row in bank.sngl_inspiral_table
	)
	best_id, best_m1, best_m2, best_s1z, best_s2z = min(candidates)
	return best_id, best_m1, best_m2, best_s1z, best_s2z
def horizon_distance_func(banks):
	"""
	Return (template_id, HorizonDistance) describing the horizon distance
	for a collection of banks.

	@param banks Iterable of Bank objects, e.g. the list returned by
	read_banks().

	The canonical template is the one chosen by
	preferred_horizon_distance_template() (smallest template_id).  The
	HorizonDistance object spans [15 Hz, 0.85 * Nyquist], where Nyquist is
	the largest of the banks' max(rate) / 2; a warning is issued if the
	banks disagree on their Nyquist frequency.

	Raises ValueError if banks is empty.
	"""
	# materialize: the banks are iterated twice below
	banks = list(banks)
	if not banks:
		raise ValueError("at least one bank is required to define a horizon distance function")

	# span is [15 Hz, 0.85 * Nyquist frequency]
	# find the Nyquist frequency for the PSD to be used for each
	# instrument.  require them to all match
	nyquists = set((max(bank.get_rates())/2. for bank in banks))
	if len(nyquists) != 1:
		warnings.warn("all banks should have the same Nyquist frequency to define a consistent horizon distance function (got %s)" % ", ".join("%g" % rate for rate in sorted(nyquists)))
	# assume default 4 s PSD.  this is not required to be correct, but
	# for best accuracy it should not be larger than the true value and
	# for best performance it should not be smaller than the true
	# value.
	deltaF = 1. / 4.
	# use the minimum template id as the canonical horizon function
	template_id, m1, m2, s1z, s2z = preferred_horizon_distance_template(banks)

	return template_id, HorizonDistance(15.0, 0.85 * max(nyquists), deltaF, m1, m2, spin1 = (0., 0., s1z), spin2 = (0., 0., s2z))
def calc_lambda_eta_sum(auto_correlation):
	"""
	Return the four lambda/eta sums for one template's autocorrelation
	series.

	The covariance matrix is built following Eqs. (45, 46) and the sums
	follow Eqs. (69)-(72) of the technical notes DCC-G2200635.

	@param auto_correlation 1-D complex array of autocorrelation samples;
	presumably of odd length and centred on its middle sample (TODO
	confirm against generate_templates()).

	Returns (lambda_sum, lambdasq_sum, lambda_etasq_sum,
	lambdasq_etasq_sum).
	"""
	n = len(auto_correlation)
	center = int((n - 1)/2.)

	# Toeplitz matrix: entry (j, k) is auto_correlation[k - j + center]
	# where that index exists, zero elsewhere
	covmat_cplx = numpy.zeros((n, n), dtype=numpy.complex128)
	for j in range(n):
		col_lo = max(j - center, 0)
		col_hi = min(j - center + n, n)
		ac_lo = max(center - j, 0)
		ac_hi = min(center - j + n, n)
		covmat_cplx[j, col_lo:col_hi] += auto_correlation[ac_lo:ac_hi]
	covmat_cplx -= numpy.outer(auto_correlation.conj(), auto_correlation)

	# real 2n x 2n representation of the complex covariance, and the
	# matching stacked (real, imag) autocorrelation vector
	covmat_real = numpy.block([[covmat_cplx.real, covmat_cplx.imag], [-covmat_cplx.imag, covmat_cplx.real]])
	ac_real = numpy.concatenate((auto_correlation.real, auto_correlation.imag))

	lambda_sum = sum(2 - 2 * abs(auto_correlation)**2)
	lambdasq_sum = numpy.trace(numpy.dot(covmat_real, covmat_real))
	lambda_etasq_sum = sum(ac_real**2)
	lambdasq_etasq_sum = numpy.dot(ac_real, numpy.dot(covmat_real, ac_real))

	return lambda_sum, lambdasq_sum, lambda_etasq_sum, lambdasq_etasq_sum