Coverage for bilby/gw/likelihood/roq.py: 92%
610 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
2import json
4import numpy as np
6from .base import GravitationalWaveTransient
7from ...core.utils import BilbyJsonEncoder, decode_bilby_json
8from ...core.utils import (
9 logger, create_frequency_series, speed_of_light, radius_of_earth
10)
11from ..prior import CBCPriorDict
12from ..utils import ln_i0
15class ROQGravitationalWaveTransient(GravitationalWaveTransient):
16 """A reduced order quadrature likelihood object
18 This uses the method described in Smith et al., (2016) Phys. Rev. D 94,
19 044031. A public repository of the ROQ data is available from
20 https://git.ligo.org/lscsoft/ROQ_data.
22 Parameters
23 ==========
24 interferometers: list, bilby.gw.detector.InterferometerList
25 A list of `bilby.detector.Interferometer` instances - contains the
26 detector data and power spectral densities
27 waveform_generator: `bilby.waveform_generator.WaveformGenerator`
28 An object which computes the frequency-domain strain of the signal,
29 given some set of parameters
30 linear_matrix: str, array_like
        Either a string pointing to the file from which to load the linear_matrix
32 array, or the array itself.
33 quadratic_matrix: str, array_like
        Either a string pointing to the file from which to load the
35 quadratic_matrix array, or the array itself.
36 roq_params: str, array_like
37 Parameters describing the domain of validity of the ROQ basis.
38 roq_params_check: bool
39 If true, run tests using the roq_params to check the prior and data are
40 valid for the ROQ
41 roq_scale_factor: float
42 The ROQ scale factor used.
43 parameter_conversion: func, optional
44 Function to update self.parameters before bases are selected based on
45 the values of self.parameters. This enables a user to switch bases
46 based on the values of parameters which are not directly used for
47 sampling.
48 priors: dict, bilby.prior.PriorDict
49 A dictionary of priors containing at least the geocent_time prior
        Warning: when using marginalisation the dict is overwritten, which will change
        the dict you are passing in. If this behaviour is undesired, pass `priors.copy()`.
52 time_marginalization: bool, optional
53 If true, marginalize over time in the likelihood.
54 The spacing of time samples can be specified through delta_tc.
55 If using time marginalisation and jitter_time is True a "jitter"
56 parameter is added to the prior which modifies the position of the
57 grid of times.
58 jitter_time: bool, optional
59 Whether to introduce a `time_jitter` parameter. This avoids either
60 missing the likelihood peak, or introducing biases in the
61 reconstructed time posterior due to an insufficient sampling frequency.
        Default is True; using this parameter is strongly encouraged.
63 delta_tc: float, optional
64 The spacing of time samples for time marginalization. If not specified,
65 it is determined based on the signal-to-noise ratio of signal.
66 distance_marginalization_lookup_table: (dict, str), optional
67 If a dict, dictionary containing the lookup_table, distance_array,
68 (distance) prior_array, and reference_distance used to construct
69 the table.
70 If a string the name of a file containing these quantities.
71 The lookup table is stored after construction in either the
72 provided string or a default location:
73 '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
74 reference_frame: (str, bilby.gw.detector.InterferometerList, list), optional
75 Definition of the reference frame for the sky location.
76 - "sky": sample in RA/dec, this is the default
77 - e.g., "H1L1", ["H1", "L1"], InterferometerList(["H1", "L1"]):
78 sample in azimuth and zenith, `azimuth` and `zenith` defined in the
          frame where the z-axis is aligned with the vector connecting H1
80 and L1.
81 time_reference: str, optional
82 Name of the reference for the sampled time parameter.
83 - "geocent"/"geocenter": sample in the time at the Earth's center,
84 this is the default
85 - e.g., "H1": sample in the time of arrival at H1
87 """
88 def __init__(
89 self, interferometers, waveform_generator, priors,
90 weights=None, linear_matrix=None, quadratic_matrix=None,
91 roq_params=None, roq_params_check=True, roq_scale_factor=1,
92 distance_marginalization=False, phase_marginalization=False,
93 time_marginalization=False, jitter_time=True, delta_tc=None,
94 distance_marginalization_lookup_table=None,
95 reference_frame="sky", time_reference="geocenter",
96 parameter_conversion=None
98 ):
99 self._delta_tc = delta_tc
100 super(ROQGravitationalWaveTransient, self).__init__(
101 interferometers=interferometers,
102 waveform_generator=waveform_generator, priors=priors,
103 distance_marginalization=distance_marginalization,
104 phase_marginalization=phase_marginalization,
105 time_marginalization=time_marginalization,
106 distance_marginalization_lookup_table=distance_marginalization_lookup_table,
107 jitter_time=jitter_time,
108 reference_frame=reference_frame,
109 time_reference=time_reference
110 )
112 self.roq_params_check = roq_params_check
113 self.roq_scale_factor = roq_scale_factor
114 if isinstance(roq_params, np.ndarray) or roq_params is None:
115 self.roq_params = roq_params
116 elif isinstance(roq_params, str):
117 self.roq_params_file = roq_params
118 self.roq_params = np.genfromtxt(roq_params, names=True)
119 else:
120 raise TypeError("roq_params should be array or str")
121 if isinstance(weights, dict):
122 self.weights = weights
123 elif isinstance(weights, str):
124 self.weights = self.load_weights(weights)
125 else:
126 is_hdf5_linear = isinstance(linear_matrix, str) and linear_matrix.endswith('.hdf5')
127 linear_matrix = self._parse_basis(linear_matrix, 'linear')
128 is_hdf5_quadratic = isinstance(quadratic_matrix, str) and quadratic_matrix.endswith('.hdf5')
129 quadratic_matrix = self._parse_basis(quadratic_matrix, 'quadratic')
130 # retrieve roq params from a basis file if it is .hdf5
131 if self.roq_params is None:
132 if is_hdf5_linear:
133 self.roq_params = np.array(
134 [(linear_matrix['minimum_frequency_hz'][()],
135 linear_matrix['maximum_frequency_hz'][()],
136 linear_matrix['duration_s'][()])],
137 dtype=[('flow', float), ('fhigh', float), ('seglen', float)]
138 )
139 if is_hdf5_quadratic:
140 if self.roq_params is None:
141 self.roq_params = np.array(
142 [(quadratic_matrix['minimum_frequency_hz'][()],
143 quadratic_matrix['maximum_frequency_hz'][()],
144 quadratic_matrix['duration_s'][()])],
145 dtype=[('flow', float), ('fhigh', float), ('seglen', float)]
146 )
147 else:
148 self.roq_params['flow'] = max(
149 self.roq_params['flow'], quadratic_matrix['minimum_frequency_hz'][()]
150 )
151 self.roq_params['fhigh'] = min(
152 self.roq_params['fhigh'], quadratic_matrix['maximum_frequency_hz'][()]
153 )
154 self.roq_params['seglen'] = min(
155 self.roq_params['seglen'], quadratic_matrix['duration_s'][()]
156 )
157 if self.roq_params is not None:
158 for ifo in self.interferometers:
159 self.perform_roq_params_check(ifo)
161 self.weights = dict()
162 self._set_weights(linear_matrix=linear_matrix, quadratic_matrix=quadratic_matrix)
163 if is_hdf5_linear:
164 linear_matrix.close()
165 if is_hdf5_quadratic:
166 quadratic_matrix.close()
168 self.number_of_bases_linear = len(self.weights[f'{self.interferometers[0].name}_linear'])
169 self.number_of_bases_quadratic = len(self.weights[f'{self.interferometers[0].name}_quadratic'])
170 self._cache = dict(parameters=None, basis_number_linear=None, basis_number_quadratic=None)
171 self.parameter_conversion = parameter_conversion
173 for basis_type in ['linear', 'quadratic']:
174 number_of_bases = getattr(self, f'number_of_bases_{basis_type}')
175 if number_of_bases > 1:
176 self._verify_numbers_of_prior_ranges_and_frequency_nodes(basis_type)
177 else:
178 self._check_frequency_nodes_exist_for_single_basis(basis_type)
179 self._verify_prior_ranges(basis_type)
181 self._set_unique_frequency_nodes_and_inverse()
182 # need to fill waveform_arguments here if single basis is used, as they will never be updated.
183 if self.number_of_bases_linear == 1 and self.number_of_bases_quadratic == 1:
184 frequency_nodes, linear_indices, quadratic_indices = \
185 self._unique_frequency_nodes_and_inverse[0][0]
186 self._waveform_generator.waveform_arguments['frequency_nodes'] = frequency_nodes
187 self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices
188 self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices
190 def _verify_numbers_of_prior_ranges_and_frequency_nodes(self, basis_type):
191 """
192 Check if self.weights contains lists of prior ranges and frequency nodes, and their sizes are equal to the
193 number of bases.
195 Parameters
196 ==========
197 basis_type: str
199 """
200 number_of_bases = getattr(self, f'number_of_bases_{basis_type}')
201 key = f'prior_range_{basis_type}'
202 try:
203 prior_ranges = self.weights[key]
204 except KeyError:
205 raise AttributeError(
206 f'For the use of multiple {basis_type} ROQ bases, weights should contain "{key}".')
207 else:
208 for param_name in prior_ranges:
209 if len(prior_ranges[param_name]) != number_of_bases:
210 raise ValueError(
211 f'The number of prior ranges for "{param_name}" does not '
212 f'match the number of {basis_type} bases')
213 key = f'frequency_nodes_{basis_type}'
214 try:
215 frequency_nodes = self.weights[key]
216 except KeyError:
217 raise AttributeError(
218 f'For the use of multiple {basis_type} ROQ bases, weights should contain "{key}".')
219 else:
220 if len(frequency_nodes) != number_of_bases:
221 raise ValueError(
222 f'The number of arrays of frequency nodes does not match the number of {basis_type} bases')
224 def _verify_prior_ranges(self, basis_type):
225 """Check if the union of prior ranges is within the ROQ basis bounds.
227 Parameters
228 ==========
229 basis_type: str
231 """
232 key = f'prior_range_{basis_type}'
233 if key not in self.weights:
234 return
235 prior_ranges = self.weights[key]
236 for param_name, prior_ranges_of_this_param in prior_ranges.items():
237 prior_minimum = self.priors[param_name].minimum
238 basis_minimum = np.min(prior_ranges_of_this_param[:, 0])
239 if prior_minimum < basis_minimum:
240 raise BilbyROQParamsRangeError(
241 f"Prior minimum of {param_name} {prior_minimum} less "
242 f"than ROQ basis bound {basis_minimum}"
243 )
245 prior_maximum = self.priors[param_name].maximum
246 basis_maximum = np.max(prior_ranges_of_this_param[:, 1])
247 if prior_maximum > basis_maximum:
248 raise BilbyROQParamsRangeError(
249 f"Prior maximum of {param_name} {prior_maximum} greater "
250 f"than ROQ basis bound {basis_maximum}"
251 )
253 def _check_frequency_nodes_exist_for_single_basis(self, basis_type):
254 """
255 For a single-basis case, frequency nodes should be contained in self._waveform_generator.waveform_arguments or
256 self.weights. This method checks if it is the case and raise AttributeError if not. This method also adds
257 frequency nodes to self._waveform_generator.waveform_arguments or self.weights from the other.
259 Parameters
260 ==========
261 basis_type: str
263 """
264 key = f'frequency_nodes_{basis_type}'
265 if not (key in self.weights or key in self._waveform_generator.waveform_arguments):
266 raise AttributeError(f'{key} should be contained in weights or waveform arguments.')
267 elif key not in self._waveform_generator.waveform_arguments:
268 self._waveform_generator.waveform_arguments[key] = self.weights[key][0]
269 elif key not in self.weights:
270 self.weights[key] = [self._waveform_generator.waveform_arguments[key]]
272 def _set_unique_frequency_nodes_and_inverse(self):
273 """Set unique frequency nodes and indices to recover linear and quadratic frequency nodes for each combination
274 of linear and quadratic bases
275 """
276 self._unique_frequency_nodes_and_inverse = []
277 for idx_linear in range(self.number_of_bases_linear):
278 tmp = []
279 frequency_nodes_linear = self.weights['frequency_nodes_linear'][idx_linear]
280 size_linear = len(frequency_nodes_linear)
281 for idx_quadratic in range(self.number_of_bases_quadratic):
282 frequency_nodes_quadratic = self.weights['frequency_nodes_quadratic'][idx_quadratic]
283 frequency_nodes_unique, original_indices = np.unique(
284 np.hstack((frequency_nodes_linear, frequency_nodes_quadratic)),
285 return_inverse=True
286 )
287 linear_indices = original_indices[:size_linear]
288 quadratic_indices = original_indices[size_linear:]
289 tmp.append(
290 (frequency_nodes_unique, linear_indices, quadratic_indices)
291 )
292 self._unique_frequency_nodes_and_inverse.append(tmp)
294 def _setup_time_marginalization(self):
295 if self._delta_tc is None:
296 self._delta_tc = self._get_time_resolution()
297 tcmin = self.priors['geocent_time'].minimum
298 tcmax = self.priors['geocent_time'].maximum
299 number_of_time_samples = int(np.ceil((tcmax - tcmin) / self._delta_tc))
300 # adjust delta tc so that the last time sample has an equal weight
301 self._delta_tc = (tcmax - tcmin) / number_of_time_samples
302 logger.info(
303 "delta tc for time marginalization = {} seconds.".format(self._delta_tc))
304 self._times = tcmin + self._delta_tc / 2. + np.arange(number_of_time_samples) * self._delta_tc
305 self._beam_pattern_reference_time = (tcmin + tcmax) / 2.
307 @staticmethod
308 def _parse_basis(basis, basis_type):
309 """
310 Parse basis and format it to an hdf5-like object
312 Parameters
313 ----------
314 basis : array-like or str
315 array-like basis or path to file
316 basis_type : str
317 'linear' or 'quadratic'
319 Returns
320 -------
321 basis : hdf5-like object
323 """
324 if basis_type not in ['linear', 'quadratic']:
325 raise ValueError(f'basis_type {basis_type} not recognized')
326 if isinstance(basis, str):
327 logger.info(f'Loading {basis_type}_matrix from {basis}')
328 format = basis.split('.')[-1]
329 if format == 'npy':
330 basis = {f'basis_{basis_type}': {'0': {'basis': np.load(basis)}}}
331 elif format == 'hdf5':
332 import h5py
333 basis = h5py.File(basis, 'r')
334 else:
335 raise IOError(f'Format {format} not recognized.')
336 elif isinstance(basis, np.ndarray):
337 basis = {f'basis_{basis_type}': {'0': {'basis': basis.T}}}
338 else:
339 raise TypeError('basis needs to be str or np.ndarray')
340 return basis
342 def _select_prior_ranges(self, prior_ranges):
343 """
344 Select prior ranges which have intersection with self.priors
346 Parameters
347 ----------
348 prior_ranges : dict
349 dictionary whose keys are parameter names and values are ndarray of
350 their prior ranges
352 Returns
353 -------
354 idxs_in_prior_range : ndarray
355 indexes of selected prior ranges
356 selected_prior_ranges : dict
358 """
359 param_names = list(prior_ranges.keys())
360 number_of_prior_ranges = len(prior_ranges[param_names[0]])
361 in_prior_range = np.ones(number_of_prior_ranges, dtype=bool)
362 for param_name in param_names:
363 try:
364 prior = self.priors[param_name]
365 except KeyError:
366 continue
367 prior_ranges_of_this_param = prior_ranges[param_name]
368 in_prior_range *= \
369 (prior_ranges_of_this_param[:, 1] >= prior.minimum) * \
370 (prior_ranges_of_this_param[:, 0] <= prior.maximum)
371 idxs_in_prior_range = np.arange(number_of_prior_ranges)[in_prior_range]
372 return idxs_in_prior_range, \
373 dict((param_name, prior_ranges[param_name][idxs_in_prior_range])
374 for param_name in param_names)
376 def _update_basis(self):
377 """
378 Update basis and frequency nodes depending on the curret values of parameters
380 This updates
381 - self._cache
382 - frequency_nodes_linear/quadratic in self._waveform_generator.waveform_arguments
384 """
385 parameters = self.parameters.copy()
386 if self.parameter_conversion is not None:
387 parameters = self.parameter_conversion(parameters)
388 for basis_type, number_of_bases in zip(
389 ['linear', 'quadratic'], [self.number_of_bases_linear, self.number_of_bases_quadratic]
390 ):
391 basis_number_key = f'basis_number_{basis_type}'
392 if number_of_bases == 1:
393 self._cache[basis_number_key] = 0
394 continue
395 in_prior_range = np.ones(number_of_bases, dtype=bool)
396 prior_range_key = f'prior_range_{basis_type}'
397 for param_name in self.weights[prior_range_key]:
398 if param_name not in parameters:
399 continue
400 in_prior_range *= \
401 (self.weights[prior_range_key][param_name][:, 0] <= parameters[param_name]) * \
402 (self.weights[prior_range_key][param_name][:, 1] >= parameters[param_name])
403 self._cache[basis_number_key] = np.arange(number_of_bases)[in_prior_range][0]
404 basis_number_linear = self._cache['basis_number_linear']
405 basis_number_quadratic = self._cache['basis_number_quadratic']
406 frequency_nodes, linear_indices, quadratic_indices = \
407 self._unique_frequency_nodes_and_inverse[basis_number_linear][basis_number_quadratic]
408 self._waveform_generator.waveform_arguments['frequency_nodes'] = frequency_nodes
409 self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices
410 self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices
411 self._cache['parameters'] = self.parameters.copy()
413 @property
414 def basis_number_linear(self):
415 if self.number_of_bases_linear > 1 or self.number_of_bases_quadratic > 1:
416 if self.parameters != self._cache['parameters']:
417 self._update_basis()
418 return self._cache['basis_number_linear']
419 else:
420 return 0
422 @property
423 def basis_number_quadratic(self):
424 if self.number_of_bases_linear > 1 or self.number_of_bases_quadratic > 1:
425 if self.parameters != self._cache['parameters']:
426 self._update_basis()
427 return self._cache['basis_number_quadratic']
428 else:
429 return 0
431 @property
432 def waveform_generator(self):
433 if getattr(self, 'number_of_bases_linear', 1) > 1 or getattr(self, 'number_of_bases_quadratic', 1) > 1:
434 if self.parameters != self._cache['parameters']:
435 self._update_basis()
436 return self._waveform_generator
438 @waveform_generator.setter
439 def waveform_generator(self, waveform_generator):
440 self._waveform_generator = waveform_generator
    def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True):
        """
        Compute the snrs for ROQ

        The waveform is projected onto the detector at the linear and quadratic
        frequency nodes of the currently selected bases, then calibrated. The
        optimal SNR squared is the quadratic weights dotted with ``|h|**2``.
        ``<d|h>`` is interpolated from the five ROQ time samples nearest to the
        arrival time in this detector.

        Parameters
        ==========
        waveform_polarizations: waveform
            Mapping with 'linear' and 'quadratic' entries, each mapping a mode
            name to that polarization evaluated at the corresponding nodes
            (presumably produced by the ROQ frequency-domain source model; verify).
        interferometer: bilby.gw.detector.Interferometer
        return_array: bool
            If True and time marginalization is enabled, also compute
            ``<d|h>`` on the marginalization time grid.

        Returns
        =======
        ``self._CalculatedSNRs`` holding ``d_inner_h``, the real part of
        ``optimal_snr_squared``, ``complex_matched_filter_snr`` and
        ``d_inner_h_array`` (None unless requested as above). If the arrival
        time is too close to the edge of the ROQ time samples, ``d_inner_h``
        and ``complex_matched_filter_snr`` are ``-np.inf``.
        """
        # With time marginalization the antenna response and time delay are
        # evaluated at a fixed reference time (midpoint of the time prior).
        if self.time_marginalization:
            time_ref = self._beam_pattern_reference_time
        else:
            time_ref = self.parameters['geocent_time']

        # Reading ``self.waveform_generator`` may switch bases for the current
        # parameters before the node indices are fetched.
        frequency_nodes = self.waveform_generator.waveform_arguments['frequency_nodes']
        linear_indices = self.waveform_generator.waveform_arguments['linear_indices']
        quadratic_indices = self.waveform_generator.waveform_arguments['quadratic_indices']
        size_linear = len(linear_indices)
        size_quadratic = len(quadratic_indices)
        h_linear = np.zeros(size_linear, dtype=complex)
        h_quadratic = np.zeros(size_quadratic, dtype=complex)
        for mode in waveform_polarizations['linear']:
            response = interferometer.antenna_response(
                self.parameters['ra'], self.parameters['dec'],
                time_ref,
                self.parameters['psi'],
                mode
            )
            h_linear += waveform_polarizations['linear'][mode] * response
            h_quadratic += waveform_polarizations['quadratic'][mode] * response

        # Calibration is evaluated once on the unique nodes and gathered back
        # onto the linear and quadratic node sets.
        calib_factor = interferometer.calibration_model.get_calibration_factor(
            frequency_nodes, prefix='recalib_{}_'.format(interferometer.name), **self.parameters)
        h_linear *= calib_factor[linear_indices]
        h_quadratic *= calib_factor[quadratic_indices]

        optimal_snr_squared = np.vdot(
            np.abs(h_quadratic)**2,
            self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic]
        )

        # Arrival time in this detector, relative to the start of its data
        # (the same origin as ``self.weights['time_samples']``).
        dt = interferometer.time_delay_from_geocenter(
            self.parameters['ra'], self.parameters['dec'], time_ref)
        dt_geocent = self.parameters['geocent_time'] - interferometer.strain_data.start_time
        ifo_time = dt_geocent + dt

        indices, in_bounds = self._closest_time_indices(
            ifo_time, self.weights['time_samples'])
        if not in_bounds:
            logger.debug("SNR calculation error: requested time at edge of ROQ time samples")
            d_inner_h = -np.inf
            complex_matched_filter_snr = -np.inf
        else:
            # <d|h> at each of the five nearest ROQ time samples.
            d_inner_h_tc_array = np.einsum(
                'i,ji->j', np.conjugate(h_linear),
                self.weights[interferometer.name + '_linear'][self.basis_number_linear][indices])

            d_inner_h = self._interp_five_samples(
                self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time)

            # Silence "invalid" floating-point warnings (e.g. 0/0).
            with np.errstate(invalid="ignore"):
                complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5)

        if return_array and self.time_marginalization:
            ifo_times = self._times - interferometer.strain_data.start_time
            ifo_times += dt
            if self.jitter_time:
                ifo_times += self.parameters['time_jitter']
            d_inner_h_array = self._calculate_d_inner_h_array(ifo_times, h_linear, interferometer.name)
        else:
            d_inner_h_array = None

        return self._CalculatedSNRs(
            d_inner_h=d_inner_h,
            optimal_snr_squared=optimal_snr_squared.real,
            complex_matched_filter_snr=complex_matched_filter_snr,
            d_inner_h_array=d_inner_h_array,
        )
522 @staticmethod
523 def _closest_time_indices(time, samples):
524 """
525 Get the closest five times
527 Parameters
528 ==========
529 time: float
530 Time to check
531 samples: array-like
532 Available times
534 Returns
535 =======
536 indices: list
537 Indices nearest to time
538 in_bounds: bool
539 Whether the indices are for valid times
540 """
541 closest = int((time - samples[0]) / (samples[1] - samples[0]))
542 indices = [closest + ii for ii in [-2, -1, 0, 1, 2]]
543 in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size)
544 return indices, in_bounds
546 @staticmethod
547 def _interp_five_samples(time_samples, values, time):
548 """
549 Interpolate a function of time with its values at the closest five times.
550 The algorithm is explained in https://dcc.ligo.org/T2100224.
552 Parameters
553 ==========
554 time_samples: array-like
555 Closest 5 times
556 values: array-like
557 The values of the function at closest 5 times
558 time: float
559 Time at which the function is calculated
561 Returns
562 =======
563 value: float
564 The value of the function at the input time
565 """
566 r1 = (-values[0] + 8. * values[1] - 14. * values[2] + 8. * values[3] - values[4]) / 4.
567 r2 = values[2] - 2. * values[3] + values[4]
568 a = (time_samples[3] - time) / (time_samples[1] - time_samples[0])
569 b = 1. - a
570 c = (a**3. - a) / 6.
571 d = (b**3. - b) / 6.
572 return a * values[2] + b * values[3] + c * r1 + d * r2
574 def _calculate_d_inner_h_array(self, times, h_linear, ifo_name):
575 """
576 Calculate d_inner_h at regularly-spaced time samples. Each value is
577 interpolated from the nearest 5 samples with the algorithm explained in
578 https://dcc.ligo.org/T2100224.
580 Parameters
581 ==========
582 times: array-like
583 Regularly-spaced time samples at which d_inner_h are calculated.
584 h_linear: array-like
585 Waveforms at linear frequency nodes
586 ifo_name: str
588 Returns
589 =======
590 d_inner_h_array: array-like
591 """
592 roq_time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0]
593 times_per_roq_time_space = (times - self.weights['time_samples'][0]) / roq_time_space
594 closest_idxs = np.floor(times_per_roq_time_space).astype(int)
595 # Get the nearest 5 samples of d_inner_h. Calculate only the required d_inner_h values if the time
596 # spacing is larger than 5 times the ROQ time spacing.
597 weights_linear = self.weights[ifo_name + '_linear'][self.basis_number_linear]
598 h_linear_conj = np.conjugate(h_linear)
599 if (times[1] - times[0]) / roq_time_space > 5:
600 d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj)
601 d_inner_h_m1 = np.dot(weights_linear[closest_idxs - 1], h_linear_conj)
602 d_inner_h_0 = np.dot(weights_linear[closest_idxs], h_linear_conj)
603 d_inner_h_p1 = np.dot(weights_linear[closest_idxs + 1], h_linear_conj)
604 d_inner_h_p2 = np.dot(weights_linear[closest_idxs + 2], h_linear_conj)
605 else:
606 d_inner_h_at_roq_time_samples = np.dot(weights_linear, h_linear_conj)
607 d_inner_h_m2 = d_inner_h_at_roq_time_samples[closest_idxs - 2]
608 d_inner_h_m1 = d_inner_h_at_roq_time_samples[closest_idxs - 1]
609 d_inner_h_0 = d_inner_h_at_roq_time_samples[closest_idxs]
610 d_inner_h_p1 = d_inner_h_at_roq_time_samples[closest_idxs + 1]
611 d_inner_h_p2 = d_inner_h_at_roq_time_samples[closest_idxs + 2]
612 # quantities required for spline interpolation
613 b = times_per_roq_time_space - closest_idxs
614 a = 1. - b
615 c = (a**3. - a) / 6.
616 d = (b**3. - b) / 6.
617 r1 = (-d_inner_h_m2 + 8. * d_inner_h_m1 - 14. * d_inner_h_0 + 8. * d_inner_h_p1 - d_inner_h_p2) / 4.
618 r2 = d_inner_h_0 - 2. * d_inner_h_p1 + d_inner_h_p2
619 return a * d_inner_h_0 + b * d_inner_h_p1 + c * r1 + d * r2
621 def perform_roq_params_check(self, ifo=None):
622 """ Perform checking that the prior and data are valid for the ROQ
624 Parameters
625 ==========
626 ifo: bilby.gw.detector.Interferometer
627 The interferometer
628 """
629 if self.roq_params_check is False:
630 logger.warning("No ROQ params checking performed")
631 return
632 else:
633 if getattr(self, "roq_params_file", None) is not None:
634 msg = ("Check ROQ params {} with roq_scale_factor={}"
635 .format(self.roq_params_file, self.roq_scale_factor))
636 else:
637 msg = ("Check ROQ params with roq_scale_factor={}"
638 .format(self.roq_scale_factor))
639 logger.info(msg)
641 roq_params = self.roq_params
642 roq_minimum_frequency = roq_params['flow'] * self.roq_scale_factor
643 roq_maximum_frequency = roq_params['fhigh'] * self.roq_scale_factor
644 roq_segment_length = roq_params['seglen'] / self.roq_scale_factor
645 try:
646 roq_minimum_chirp_mass = roq_params['chirpmassmin'] / self.roq_scale_factor
647 except ValueError:
648 roq_minimum_chirp_mass = None
649 try:
650 roq_maximum_chirp_mass = roq_params['chirpmassmax'] / self.roq_scale_factor
651 except ValueError:
652 roq_maximum_chirp_mass = None
653 try:
654 roq_minimum_component_mass = roq_params['compmin'] / self.roq_scale_factor
655 except ValueError:
656 roq_minimum_component_mass = None
658 if ifo.maximum_frequency > roq_maximum_frequency:
659 raise BilbyROQParamsRangeError(
660 "Requested maximum frequency {} larger than ROQ basis fhigh {}"
661 .format(ifo.maximum_frequency, roq_maximum_frequency)
662 )
663 if ifo.minimum_frequency < roq_minimum_frequency:
664 raise BilbyROQParamsRangeError(
665 "Requested minimum frequency {} lower than ROQ basis flow {}"
666 .format(ifo.minimum_frequency, roq_minimum_frequency)
667 )
668 if ifo.strain_data.duration != roq_segment_length:
669 raise BilbyROQParamsRangeError(
670 "Requested duration differs from ROQ basis seglen")
672 priors = self.priors
673 if isinstance(priors, CBCPriorDict) is False:
674 logger.warning("Unable to check ROQ parameter bounds: priors not understood")
675 return
677 if roq_minimum_chirp_mass is not None:
678 if priors.minimum_chirp_mass is None:
679 logger.warning("Unable to check minimum chirp mass ROQ bounds")
680 elif priors.minimum_chirp_mass < roq_minimum_chirp_mass:
681 raise BilbyROQParamsRangeError(
682 "Prior minimum chirp mass {} less than ROQ basis bound {}"
683 .format(priors.minimum_chirp_mass, roq_minimum_chirp_mass)
684 )
686 if roq_maximum_chirp_mass is not None:
687 if priors.maximum_chirp_mass is None:
688 logger.warning("Unable to check maximum_chirp mass ROQ bounds")
689 elif priors.maximum_chirp_mass > roq_maximum_chirp_mass:
690 raise BilbyROQParamsRangeError(
691 "Prior maximum chirp mass {} greater than ROQ basis bound {}"
692 .format(priors.maximum_chirp_mass, roq_maximum_chirp_mass)
693 )
695 if roq_minimum_component_mass is not None:
696 if priors.minimum_component_mass is None:
697 logger.warning("Unable to check minimum component mass ROQ bounds")
698 elif priors.minimum_component_mass < roq_minimum_component_mass:
699 raise BilbyROQParamsRangeError(
700 "Prior minimum component mass {} less than ROQ basis bound {}"
701 .format(priors.minimum_component_mass, roq_minimum_component_mass)
702 )
    def _set_weights(self, linear_matrix, quadratic_matrix):
        """
        Setup the time-dependent ROQ weights.

        Steps:
        1. Choose the ROQ time samples (relative to
           ``self.interferometers.start_time``) covering the prior on
           ``<time_reference>_time``, padded on both sides.
        2. For each basis type, select the bases whose stored prior ranges
           overlap the priors (or basis 0 if no ranges are stored), and record
           their rescaled frequency nodes if the basis provides them.
        3. For non-multibanded bases, match basis components to the
           interferometers' masked frequency arrays.
        4. Delegate to the linear/quadratic (multiband) weight builders.

        Parameters
        ==========
        linear_matrix, quadratic_matrix: dictionary or h5py.File
            linear and quadratic basis

        Raises
        ======
        BilbyROQParamsRangeError
            If no basis of some type overlaps the prior range.
        ValueError
            If, without roq_params, a basis and a frequency array differ in length.
        """
        time_space = self._get_time_resolution()
        number_of_time_samples = int(self.interferometers.duration / time_space)
        # Padding: 2 * R_earth / c (light travel time across the Earth's
        # diameter) plus five time steps.
        earth_light_crossing_time = 2 * radius_of_earth / speed_of_light + 5 * time_space
        # Index range clipped to [0, number_of_time_samples - 1].
        start_idx = max(
            0,
            int(np.floor((
                self.priors['{}_time'.format(self.time_reference)].minimum
                - earth_light_crossing_time
                - self.interferometers.start_time
            ) / time_space))
        )
        end_idx = min(
            number_of_time_samples - 1,
            int(np.ceil((
                self.priors['{}_time'.format(self.time_reference)].maximum
                + earth_light_crossing_time
                - self.interferometers.start_time
            ) / time_space))
        )
        self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * time_space
        logger.info("Using {} ROQ time samples".format(len(self.weights['time_samples'])))

        # select bases to be used, set prior ranges and frequency nodes if exist
        idxs_in_prior_range = dict()
        for basis_type, matrix in zip(['linear', 'quadratic'], [linear_matrix, quadratic_matrix]):
            key = f'prior_range_{basis_type}'
            if key in matrix:
                prior_ranges = {}
                for param_name in matrix[key]:
                    # A stored ``roq_scale_power`` attribute says how this
                    # parameter's bounds scale with roq_scale_factor;
                    # without it the bounds are left unscaled.
                    if 'roq_scale_power' in matrix[key][param_name].attrs:
                        roq_scale_factor = self.roq_scale_factor**matrix[key][param_name].attrs['roq_scale_power']
                    else:
                        roq_scale_factor = 1.
                    prior_ranges[param_name] = matrix[key][param_name][()] * roq_scale_factor
                selected_idxs, selected_prior_ranges = self._select_prior_ranges(prior_ranges)
                if len(selected_idxs) == 0:
                    raise BilbyROQParamsRangeError(f"There are no {basis_type} ROQ bases within the prior range.")
                self.weights[key] = selected_prior_ranges
                idxs_in_prior_range[basis_type] = selected_idxs
            else:
                idxs_in_prior_range[basis_type] = [0]
            # Only the first selected basis is probed; all selected bases are
            # assumed to store frequency nodes if it does.
            if 'frequency_nodes' in matrix[f'basis_{basis_type}'][str(idxs_in_prior_range[basis_type][0])]:
                self.weights[f'frequency_nodes_{basis_type}'] = [
                    matrix[f'basis_{basis_type}'][str(i)]['frequency_nodes'][()] * self.roq_scale_factor
                    for i in idxs_in_prior_range[basis_type]]

        if 'multiband_linear' in linear_matrix:
            multiband_linear = linear_matrix['multiband_linear'][()]
        else:
            multiband_linear = False
        if 'multiband_quadratic' in quadratic_matrix:
            multiband_quadratic = quadratic_matrix['multiband_quadratic'][()]
        else:
            multiband_quadratic = False

        # Get intersection between ifo and ROQ frequency samples. Required only for non-multibanded basis.
        # roq_idxs / ifo_idxs are left undefined when both bases are
        # multibanded; they are then not passed to any builder below.
        if not (multiband_linear and multiband_quadratic):
            roq_idxs = {}
            ifo_idxs = {}
            for ifo in self.interferometers:
                if self.roq_params is not None:
                    # Get scaled ROQ quantities
                    roq_scaled_minimum_frequency = self.roq_params['flow'] * self.roq_scale_factor
                    roq_scaled_maximum_frequency = self.roq_params['fhigh'] * self.roq_scale_factor
                    roq_scaled_segment_length = self.roq_params['seglen'] / self.roq_scale_factor
                    # Generate frequencies for the ROQ
                    roq_frequencies = create_frequency_series(
                        sampling_frequency=roq_scaled_maximum_frequency * 2,
                        duration=roq_scaled_segment_length)
                    roq_mask = roq_frequencies >= roq_scaled_minimum_frequency
                    roq_frequencies = roq_frequencies[roq_mask]
                    overlap_frequencies, ifo_idxs_this_ifo, roq_idxs_this_ifo = np.intersect1d(
                        ifo.frequency_array[ifo.frequency_mask], roq_frequencies,
                        return_indices=True)
                else:
                    # Without roq_params, basis components are assumed to map
                    # one-to-one onto the masked frequency array.
                    overlap_frequencies = ifo.frequency_array[ifo.frequency_mask]
                    roq_idxs_this_ifo = np.arange(
                        linear_matrix['basis_linear'][str(idxs_in_prior_range['linear'][0])]['basis'].shape[1],
                        dtype=int)
                    ifo_idxs_this_ifo = np.arange(sum(ifo.frequency_mask))
                    if len(ifo_idxs_this_ifo) != len(roq_idxs_this_ifo):
                        raise ValueError(
                            "Mismatch between ROQ basis and frequency array for "
                            "{}".format(ifo.name))
                logger.info(
                    "Building ROQ weights for {} with {} frequencies between {} "
                    "and {}.".format(
                        ifo.name, len(overlap_frequencies),
                        min(overlap_frequencies), max(overlap_frequencies)))
                roq_idxs[ifo.name] = roq_idxs_this_ifo
                ifo_idxs[ifo.name] = ifo_idxs_this_ifo

        if multiband_linear:
            self._set_weights_linear_multiband(linear_matrix, idxs_in_prior_range['linear'])
        else:
            self._set_weights_linear(linear_matrix, idxs_in_prior_range['linear'], roq_idxs, ifo_idxs)

        if multiband_quadratic:
            self._set_weights_quadratic_multiband(quadratic_matrix, idxs_in_prior_range['quadratic'])
        else:
            self._set_weights_quadratic(quadratic_matrix, idxs_in_prior_range['quadratic'], roq_idxs, ifo_idxs)
def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs):
    """
    Setup the time-dependent linear ROQ weights. See https://dcc.ligo.org/LIGO-T2100125 for the detail of how to
    compute them.

    For each requested basis and each interferometer, an array of shape
    ``(len(self.weights['time_samples']), basis_size)`` is appended to
    ``self.weights['<ifo>_linear']``. Column ``i`` is the inverse FFT of
    ``d / S * conj(basis_i)`` placed on the data frequency bins. It is evaluated
    at the time samples and scaled by ``4 * number_of_time_samples / duration``.

    Parameters
    ==========
    linear_matrix : dictionary or h5py.File
        linear basis
    basis_idxs : array-like
        indexes of bases used for a run
    roq_idxs : dictionary
        dictionary whose keys are interferometer names and values are indexes of basis components intersecting
        frequency-domain data
    ifo_idxs : dictionary
        dictionary whose keys are interferometer names and values are indexes of frequency-domain data intersecting
        basis components

    """
    for ifo in self.interferometers:
        self.weights[ifo.name + '_linear'] = []
    time_samples = self.weights['time_samples']
    duration = self.interferometers.duration
    time_space = time_samples[1] - time_samples[0]
    # These ratios are integers in exact arithmetic. Round instead of truncating,
    # so that floating-point error (e.g. 16383.9999...) cannot shift an index or
    # FFT length by one.
    number_of_time_samples = int(np.rint(duration / time_space))
    start_idx = int(np.rint(time_samples[0] / time_space))
    # Deriving the end from the number of time samples guarantees that the sliced
    # IFFT output always has the length of a ``linear_weights`` column.
    end_idx = start_idx + len(time_samples)
    nonzero_idxs = {}
    data_over_psd = {}
    for ifo in self.interferometers:
        mask = ifo.frequency_mask
        # frequency_array values may carry floating-point error, so round to the
        # nearest frequency bin rather than truncating.
        first_bin = int(np.rint(ifo.frequency_array[mask][0] * duration))
        nonzero_idxs[ifo.name] = ifo_idxs[ifo.name] + first_bin
        data_over_psd[ifo.name] = (
            ifo.frequency_domain_strain[mask][ifo_idxs[ifo.name]]
            / ifo.power_spectral_density_array[mask][ifo_idxs[ifo.name]]
        )
    try:
        import pyfftw
    except ImportError:
        pyfftw = None
        logger.warning("You do not have pyfftw installed, falling back to numpy.fft.")
    if pyfftw is not None:
        ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex)
        ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex)
        ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD')
    else:
        ifft_input = np.zeros(number_of_time_samples, dtype=complex)
        ifft = np.fft.ifft
    try:
        for basis_idx in basis_idxs:
            logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.")
            linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis']
            basis_size = linear_matrix_single.shape[0]
            for ifo in self.interferometers:
                # Assign zero rather than multiplying by zero. The pyfftw buffer
                # starts out uninitialised and may hold NaN/inf, and NaN * 0 is NaN.
                ifft_input[:] = 0.
                linear_weights = np.zeros((len(time_samples), basis_size), dtype=complex)
                for i in range(basis_size):
                    # Rows are read one at a time; linear_matrix may be an on-disk h5py dataset.
                    basis_element = linear_matrix_single[i][roq_idxs[ifo.name]]
                    ifft_input[nonzero_idxs[ifo.name]] = data_over_psd[ifo.name] * np.conj(basis_element)
                    linear_weights[:, i] = ifft(ifft_input)[start_idx:end_idx]
                linear_weights *= 4. * number_of_time_samples / duration
                self.weights[ifo.name + '_linear'].append(linear_weights)
    finally:
        # Release the accumulated FFTW plans even if building the weights failed.
        if pyfftw is not None:
            pyfftw.forget_wisdom()
875 def _set_weights_linear_multiband(self, linear_matrix, basis_idxs):
876 """
877 Setup the time-dependent linear ROQ weights from multibanded basis
879 Parameters
880 ==========
881 linear_matrix : dictionary or h5py.File
882 linear basis
883 basis_idxs : array-like
884 indexes of bases used for a run
886 """
887 for ifo in self.interferometers:
888 self.weights[ifo.name + '_linear'] = []
889 Tbs = linear_matrix['durations_s_linear'][()] / self.roq_scale_factor
890 start_end_frequency_bins = linear_matrix['start_end_frequency_bins_linear'][()]
891 basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1)
892 fhigh_basis = np.max(start_end_frequency_bins[:, 1] / Tbs)
893 # prepare time-shifted data, which is multiplied by basis
894 tc_shifted_data = dict()
895 for ifo in self.interferometers:
896 over_whitened_frequency_data = np.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex)
897 over_whitened_frequency_data[np.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask]] = \
898 ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask]
899 over_whitened_time_data = np.fft.irfft(over_whitened_frequency_data)
900 tc_shifted_data[ifo.name] = np.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex)
901 start_idx_of_band = 0
902 for b, Tb in enumerate(Tbs):
903 start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b]
904 fs = np.arange(start_frequency_bin, end_frequency_bin + 1) / Tb
905 Db = np.fft.rfft(
906 over_whitened_time_data[-int(2. * fhigh_basis * Tb):]
907 )[start_frequency_bin:end_frequency_bin + 1]
908 start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1
909 tc_shifted_data[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * Db[:, None] * np.exp(
910 2. * np.pi * 1j * fs[:, None] * (self.weights['time_samples'][None, :] - ifo.duration + Tb))
911 start_idx_of_band = start_idx_of_next_band
912 # compute inner products
913 for basis_idx in basis_idxs:
914 logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.")
915 linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'][()]
916 for ifo in self.interferometers:
917 self.weights[ifo.name + '_linear'].append(
918 np.dot(np.conj(linear_matrix_single), tc_shifted_data[ifo.name]).T)
920 def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idxs):
921 """
922 Setup the quadratic ROQ weights
924 Parameters
925 ==========
926 quadratic_matrix : dictionary or h5py.File
927 quadratic basis
928 basis_idxs : array-like
929 indexes of bases used for a run
930 roq_idxs : dictionary
931 dictionary whose keys are interferometer names and values are indexes of basis components intersecting
932 frequency-domain data
933 ifo_idxs : dictionary
934 dictionary whose keys are interferometer names and values are indexes of frequency-domain data intersecting
935 basis components
937 """
938 for ifo in self.interferometers:
939 self.weights[ifo.name + '_quadratic'] = []
940 for basis_idx in basis_idxs:
941 logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.")
942 quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real
943 for ifo in self.interferometers:
944 self.weights[ifo.name + '_quadratic'].append(
945 4. / ifo.strain_data.duration * np.dot(
946 quadratic_matrix_single[:, roq_idxs[ifo.name]],
947 1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]]))
948 del quadratic_matrix_single
950 def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs):
951 """
952 Setup the quadratic ROQ weights from multibanded basis
954 Parameters
955 ==========
956 quadratic_matrix : dictionary or h5py.File
957 quadratic basis
958 basis_idxs : array-like
959 indexes of bases used for a run
961 """
962 for ifo in self.interferometers:
963 self.weights[ifo.name + '_quadratic'] = []
964 Tbs = quadratic_matrix['durations_s_quadratic'][()] / self.roq_scale_factor
965 start_end_frequency_bins = quadratic_matrix['start_end_frequency_bins_quadratic'][()]
966 basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1)
967 fhigh_basis = np.max(start_end_frequency_bins[:, 1] / Tbs)
968 # prepare coefficients multiplied by basis
969 multibanded_inverse_psd = dict()
970 for ifo in self.interferometers:
971 inverse_psd_frequency = np.zeros(int(fhigh_basis * ifo.duration) + 1)
972 inverse_psd_frequency[np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask]] = \
973 1. / ifo.power_spectral_density_array[ifo.frequency_mask]
974 inverse_psd_time = np.fft.irfft(inverse_psd_frequency)
975 multibanded_inverse_psd[ifo.name] = np.zeros(basis_dimension)
976 start_idx_of_band = 0
977 for b, Tb in enumerate(Tbs):
978 start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b]
979 number_of_samples_half = int(fhigh_basis * Tb)
980 start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1
981 multibanded_inverse_psd[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * np.fft.rfft(
982 np.append(inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:])
983 )[start_frequency_bin:end_frequency_bin + 1].real
984 start_idx_of_band = start_idx_of_next_band
985 # compute inner products
986 for basis_idx in basis_idxs:
987 logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.")
988 quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real
989 for ifo in self.interferometers:
990 self.weights[ifo.name + '_quadratic'].append(
991 np.dot(quadratic_matrix_single, multibanded_inverse_psd[ifo.name]))
def save_weights(self, filename, format='hdf5'):
    """
    Save ROQ weights into a single file. format should be npz or hdf5.
    Weights built from multiple bases can only be saved in hdf5 format.
    Support for json format is deprecated as of :code:`v2.1` and will be
    removed in :code:`v2.2`, another method should be used by default.

    Parameters
    ==========
    filename : str
        The name of the file to save the weights to. :code:`"." + format` is
        appended unless filename already ends with that extension.
    format : str
        The format to save the data to, this should be one of
        :code:`"hdf5"`, :code:`"npz"`, default=:code:`"hdf5"`.

    Raises
    ======
    IOError
        If format is not one of "json", "npz" or "hdf5".
    ValueError
        If format is "json" or "npz" and more than one linear or quadratic
        basis is in use.
    """
    if format not in ['json', 'npz', 'hdf5']:
        raise IOError(f"Format {format} not recognized.")
    if format == "json":
        import warnings

        warnings.warn(
            "json format for ROQ weights is deprecated, use hdf5 instead.",
            DeprecationWarning
        )
    # Check the real extension rather than testing for a substring. With
    # ``format in filename``, names such as "hdf5_weights" were saved with no
    # extension at all, and load_weights could then not infer their format.
    extension = "." + format
    if not filename.endswith(extension):
        filename += extension
    logger.info(f"Saving ROQ weights to {filename}")
    if format in ('json', 'npz'):
        if self.number_of_bases_linear > 1 or self.number_of_bases_quadratic > 1:
            raise ValueError(f'Format {format} not compatible with multiple bases')
        # Single-basis layout: one array per key.
        weights = dict()
        weights['time_samples'] = self.weights['time_samples']
        for basis_type in ['linear', 'quadratic']:
            for ifo in self.interferometers:
                key = f'{ifo.name}_{basis_type}'
                weights[key] = self.weights[key][0]
        if format == 'json':
            with open(filename, 'w') as file:
                json.dump(weights, file, indent=2, cls=BilbyJsonEncoder)
        else:
            np.savez(filename, **weights)
    else:
        import h5py
        with h5py.File(filename, 'w') as f:
            f.create_dataset('time_samples',
                             data=self.weights['time_samples'])
            for basis_type in ['linear', 'quadratic']:
                key = f'prior_range_{basis_type}'
                if key in self.weights:
                    grp = f.create_group(key)
                    for param_name in self.weights[key]:
                        grp.create_dataset(
                            param_name, data=self.weights[key][param_name])
                key = f'frequency_nodes_{basis_type}'
                if key in self.weights:
                    grp = f.create_group(key)
                    for i in range(len(self.weights[key])):
                        grp.create_dataset(
                            str(i), data=self.weights[key][i])
                for ifo in self.interferometers:
                    key = f"{ifo.name}_{basis_type}"
                    grp = f.create_group(key)
                    # One dataset per basis, named by its position in the list.
                    for i in range(len(self.weights[key])):
                        grp.create_dataset(
                            str(i), data=self.weights[key][i])
def load_weights(self, filename, format=None):
    """
    Load ROQ weights. format should be json, npz, or hdf5.
    json or npz file is assumed to contain weights from a single basis.
    Support for json format is deprecated as of :code:`v2.1` and will be
    removed in :code:`v2.2`, another method should be used by default.

    Parameters
    ==========
    filename : str
        The name of the file to load the weights from.
    format : str, optional
        The format of the file, one of :code:`"hdf5"`, :code:`"npz"` (or the
        deprecated :code:`"json"`). If :code:`None` (default), the format is
        taken from the text after the last "." in filename.

    Returns
    =======
    weights: dict
        Dictionary containing the ROQ weights. The per-interferometer entries
        are lists holding one array per basis.

    Raises
    ======
    IOError
        If the format is not recognized.
    """
    if format is None:
        format = filename.split(".")[-1]
    if format not in ["json", "npz", "hdf5"]:
        raise IOError(f"Format {format} not recognized.")
    if format == "json":
        import warnings

        warnings.warn(
            "json format for ROQ weights is deprecated, use hdf5 instead.",
            DeprecationWarning
        )
    logger.info(f"Loading ROQ weights from {filename}")
    if format == "json" or format == "npz":
        # Old file format assumed to contain only a single basis
        if format == "json":
            with open(filename, 'r') as file:
                weights = json.load(file, object_hook=decode_bilby_json)
        else:
            # np.load on an .npz returns a lazily-reading NpzFile that keeps
            # the file open. Copy every array into memory, then close it.
            with np.load(filename) as npz_file:
                weights = dict(npz_file)
        for basis_type in ['linear', 'quadratic']:
            for ifo in self.interferometers:
                key = f'{ifo.name}_{basis_type}'
                weights[key] = [weights[key]]
    else:
        weights = dict()
        import h5py
        with h5py.File(filename, 'r') as f:
            weights['time_samples'] = f['time_samples'][()]
            for basis_type in ['linear', 'quadratic']:
                key = f'prior_range_{basis_type}'
                if key in f:
                    idxs_in_prior_range, selected_prior_ranges = \
                        self._select_prior_ranges(f[key])
                    weights[key] = selected_prior_ranges
                else:
                    idxs_in_prior_range = [0]
                key = f'frequency_nodes_{basis_type}'
                if key in f:
                    weights[key] = [f[key][str(i)][()]
                                    for i in idxs_in_prior_range]
                for ifo in self.interferometers:
                    key = f"{ifo.name}_{basis_type}"
                    weights[key] = [f[key][str(i)][()]
                                    for i in idxs_in_prior_range]
    return weights
def _get_time_resolution(self):
    """
    Estimate the time step used when constructing the ROQ weights.

    A characteristic high frequency is computed from the masked PSD with an
    SNR-dependent scaling. The step starts as a fifth of its inverse. It is then
    reduced so that ``duration / delta_t`` is a power of two (needed for the IFFT),
    with at least ``frequency_array[-1] * duration + 1`` time samples.

    Each detector's SNR is read from ``meta_data['optimal_SNR']``, taken as 30
    when that key is absent, and floored at 10. The values are combined in
    quadrature.

    Returns
    =======
    delta_t: float
        Time resolution
    """

    def _f_high(freq, psd, scaling=20.):
        # scaling * f_3_bar**(1/3), where f_3_bar is the integral of f^(2/3) / psd
        # normalised by the integral of f^(-7/3) / psd.
        from scipy.integrate import simpson
        normalisation = simpson(y=np.power(freq, -7. / 3) / psd, x=freq)
        f_3_bar = simpson(y=np.power(freq, 2. / 3.) / (psd * normalisation), x=freq)
        return scaling * f_3_bar**(1 / 3)

    snr_squared = 0
    for ifo in self.interferometers:
        snr_squared += max(10, ifo.meta_data.get('optimal_SNR', 30))**2

    # NOTE(review): ``ifo`` is the last interferometer of the loop above. Only
    # its PSD enters the frequency estimate, while the SNR combines all of them.
    frequencies = ifo.frequency_array[ifo.frequency_mask]
    masked_psd = ifo.power_spectral_density_array[ifo.frequency_mask]
    network_snr = snr_squared**0.5
    f_high = _f_high(frequencies, masked_psd, scaling=(np.pi**2 * network_snr**2 / 6)**(1 / 3))

    # Safety factor of 5 to ensure the time step is short enough.
    delta_t = f_high**-1 / 5

    # duration / delta_t needs to be a power of 2 for IFFT
    duration = self.interferometers.duration
    minimum_samples = max(
        duration / delta_t,
        self.interferometers.frequency_array[-1] * duration + 1)
    number_of_time_samples = int(2**np.ceil(np.log2(minimum_samples)))
    delta_t = duration / number_of_time_samples
    logger.info("ROQ time-step = {}".format(delta_t))
    return delta_t
1194 def _rescale_signal(self, signal, new_distance):
1195 for kind in ['linear', 'quadratic']:
1196 for mode in signal[kind]:
1197 signal[kind][mode] *= self._ref_dist / new_distance
def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations=None):
    """
    Draw a time sample from the likelihood evaluated on the ROQ time grid.

    The log-likelihood on the grid is the distance-marginalised form when
    ``self.distance_marginalization`` is set, otherwise the phase-marginalised
    form when ``self.phase_marginalization`` is set, otherwise the plain
    ``Re(d|h) - (h|h) / 2``. It is exponentiated, weighted by the
    ``geocent_time`` prior, normalised, and used to choose one time.

    Parameters
    ==========
    signal_polarizations: dict, optional
        Waveform polarizations. When None they are generated from
        ``self.parameters`` with the waveform generator.

    Returns
    =======
    float
        A time from ``self._times``, shifted by ``self.parameters["time_jitter"]``
        when ``self.jitter_time`` is set.
    """
    from ...core.utils.random import rng

    self.parameters.update(self.get_sky_frame_parameters())
    if signal_polarizations is None:
        signal_polarizations = self.waveform_generator.frequency_domain_strain(self.parameters)

    snrs = self._CalculatedSNRs()
    for detector in self.interferometers:
        snrs += self.calculate_snrs(
            waveform_polarizations=signal_polarizations,
            interferometer=detector
        )
    d_inner_h = snrs.d_inner_h_array
    h_inner_h = snrs.optimal_snr_squared

    if self.distance_marginalization:
        log_like = self.distance_marginalized_likelihood(d_inner_h, h_inner_h)
    elif self.phase_marginalization:
        log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2
    else:
        log_like = d_inner_h.real - h_inner_h.real / 2

    candidate_times = self._times + self.parameters["time_jitter"] if self.jitter_time else self._times
    # Subtract the maximum before exponentiating to avoid overflow.
    posterior = np.exp(log_like - max(log_like)) * self.priors['geocent_time'].prob(candidate_times)
    posterior /= np.sum(posterior)
    return rng.choice(candidate_times, p=posterior)
class BilbyROQParamsRangeError(Exception):
    """Error raised when the ROQ bases cannot be used with the requested prior range.

    For example, it is raised when none of the available bases lies within the
    prior range.
    """