Coverage for bilby/gw/likelihood/base.py: 86%
534 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
2import os
3import copy
5import attr
6import numpy as np
7from scipy.special import logsumexp
9from ...core.likelihood import Likelihood
10from ...core.utils import logger, BoundedRectBivariateSpline, create_time_series
11from ...core.prior import Interped, Prior, Uniform, DeltaFunction
12from ..detector import InterferometerList, get_empty_interferometer, calibration
13from ..prior import BBHPriorDict, Cosmological
14from ..utils import noise_weighted_inner_product, zenith_azimuth_to_ra_dec, ln_i0
17class GravitationalWaveTransient(Likelihood):
18 """ A gravitational-wave transient likelihood object
20 This is the usual likelihood object to use for transient gravitational
21 wave parameter estimation. It computes the log-likelihood in the frequency
22 domain assuming a colored Gaussian noise model described by a power
23 spectral density. See Thrane & Talbot (2019), arxiv.org/abs/1809.02293.
25 Parameters
26 ==========
27 interferometers: list, bilby.gw.detector.InterferometerList
28 A list of `bilby.detector.Interferometer` instances - contains the
29 detector data and power spectral densities
30 waveform_generator: `bilby.waveform_generator.WaveformGenerator`
31 An object which computes the frequency-domain strain of the signal,
32 given some set of parameters
33 distance_marginalization: bool, optional
34 If true, marginalize over distance in the likelihood.
35 This uses a look up table calculated at run time.
36 The distance prior is set to be a delta function at the median of the
37 distance prior being marginalised over (its 0.5 quantile).
38 time_marginalization: bool, optional
39 If true, marginalize over time in the likelihood.
40 This uses a FFT to calculate the likelihood over a regularly spaced
41 grid.
42 In order to cover the whole space the prior is set to be uniform over
43 the spacing of the array of times.
44 If using time marginalisation and jitter_time is True a "jitter"
45 parameter is added to the prior which modifies the position of the
46 grid of times.
47 phase_marginalization: bool, optional
48 If true, marginalize over phase in the likelihood.
49 This is done analytically using a Bessel function.
50 The phase prior is set to be a delta function at phase=0.
51 calibration_marginalization: bool, optional
52 If true, marginalize over calibration response curves in the likelihood.
53 This is done numerically over a number of calibration response curve realizations.
54 priors: dict, optional
55 If given, used in the distance and phase marginalization.
56 Warning: when using marginalisation the dict is overwritten which will change the
57 dict you are passing in. If this behaviour is undesired, pass `priors.copy()`.
58 distance_marginalization_lookup_table: (dict, str), optional
59 If a dict, dictionary containing the lookup_table, distance_array,
60 (distance) prior_array, and reference_distance used to construct
61 the table.
62 If a string the name of a file containing these quantities.
63 The lookup table is stored after construction in either the
64 provided string or a default location:
65 '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
66 calibration_lookup_table: dict, optional
67 If a dict, contains the arrays over which to marginalize for each interferometer or the filepaths of the
68 calibration files.
69 If not provided, but calibration_marginalization is used, then the appropriate file is created to
70 contain the curves.
71 number_of_response_curves: int, optional
72 Number of curves from the calibration lookup table to use.
73 Default is 1000.
74 starting_index: int, optional
75 Sets the index for the first realization of the calibration curve to be considered.
76 This, coupled with number_of_response_curves, allows for restricting the set of curves used. This can be used
77 when dealing with large frequency arrays to split the calculation into sections.
78 Defaults to 0.
79 jitter_time: bool, optional
80 Whether to introduce a `time_jitter` parameter. This avoids either
81 missing the likelihood peak, or introducing biases in the
82 reconstructed time posterior due to an insufficient sampling frequency.
83 Default is True; using this parameter is strongly encouraged.
84 reference_frame: (str, bilby.gw.detector.InterferometerList, list), optional
85 Definition of the reference frame for the sky location.
87 - :code:`sky`: sample in RA/dec, this is the default
88 - e.g., :code:`"H1L1", ["H1", "L1"], InterferometerList(["H1", "L1"])`:
89 sample in azimuth and zenith, `azimuth` and `zenith` defined in the
90 frame where the z-axis is aligned with the vector connecting H1
91 and L1.
93 time_reference: str, optional
94 Name of the reference for the sampled time parameter.
96 - :code:`geocent`/:code:`geocenter`: sample in the time at the
97 Earth's center, this is the default
98 - e.g., :code:`H1`: sample in the time of arrival at H1
100 Returns
101 =======
102 Likelihood: `bilby.core.likelihood.Likelihood`
103 A likelihood object, able to compute the likelihood of the data given
104 some model parameters
106 """
108 @attr.s(slots=True, weakref_slot=False)
109 class _CalculatedSNRs:
110 d_inner_h = attr.ib(default=0j, converter=complex)
111 optimal_snr_squared = attr.ib(default=0, converter=float)
112 complex_matched_filter_snr = attr.ib(default=0j, converter=complex)
113 d_inner_h_array = attr.ib(default=None)
114 optimal_snr_squared_array = attr.ib(default=None)
116 def __add__(self, other_snr):
117 new = copy.deepcopy(self)
118 new += other_snr
119 return new
121 def __iadd__(self, other_snr):
122 for key in self.__slots__:
123 this = getattr(self, key)
124 other = getattr(other_snr, key)
125 if this is not None and other is not None:
126 setattr(self, key, this + other)
127 elif this is None:
128 setattr(self, key, other)
129 return self
131 @property
132 def snrs_as_sample(self) -> dict:
133 """Get the SNRs of this object as a sample dictionary
135 Returns
136 =======
137 dict
138 The dictionary of SNRs labelled accordingly
139 """
140 return {
141 "matched_filter_snr" : self.complex_matched_filter_snr,
142 "optimal_snr" : self.optimal_snr_squared.real ** 0.5
143 }
145 def __init__(
146 self, interferometers, waveform_generator, time_marginalization=False,
147 distance_marginalization=False, phase_marginalization=False, calibration_marginalization=False, priors=None,
148 distance_marginalization_lookup_table=None, calibration_lookup_table=None,
149 number_of_response_curves=1000, starting_index=0, jitter_time=True, reference_frame="sky",
150 time_reference="geocenter"
151 ):
153 self.waveform_generator = waveform_generator
154 super(GravitationalWaveTransient, self).__init__(dict())
155 self.interferometers = InterferometerList(interferometers)
156 self.time_marginalization = time_marginalization
157 self.distance_marginalization = distance_marginalization
158 self.phase_marginalization = phase_marginalization
159 self.calibration_marginalization = calibration_marginalization
160 self.priors = priors
161 self._check_set_duration_and_sampling_frequency_of_waveform_generator()
162 self._noise_log_likelihood_value = None
163 self.jitter_time = jitter_time
164 self.reference_frame = reference_frame
165 if "geocent" not in time_reference:
166 self.time_reference = time_reference
167 self.reference_ifo = get_empty_interferometer(self.time_reference)
168 if self.time_marginalization:
169 logger.info("Cannot marginalise over non-geocenter time.")
170 self.time_marginalization = False
171 self.jitter_time = False
172 else:
173 self.time_reference = "geocent"
174 self.reference_ifo = None
176 if self.time_marginalization:
177 self._check_marginalized_prior_is_set(key='geocent_time')
178 self._setup_time_marginalization()
179 priors['geocent_time'] = float(self.interferometers.start_time)
180 if self.jitter_time:
181 priors['time_jitter'] = Uniform(
182 minimum=- self._delta_tc / 2,
183 maximum=self._delta_tc / 2,
184 boundary='periodic',
185 name="time_jitter",
186 latex_label="$t_j$"
187 )
188 self._marginalized_parameters.append('geocent_time')
189 elif self.jitter_time:
190 logger.debug(
191 "Time jittering requested with non-time-marginalised "
192 "likelihood, ignoring.")
193 self.jitter_time = False
195 if self.phase_marginalization:
196 self._check_marginalized_prior_is_set(key='phase')
197 priors['phase'] = float(0)
198 self._marginalized_parameters.append('phase')
200 if self.distance_marginalization:
201 self._lookup_table_filename = None
202 self._check_marginalized_prior_is_set(key='luminosity_distance')
203 self._distance_array = np.linspace(
204 self.priors['luminosity_distance'].minimum,
205 self.priors['luminosity_distance'].maximum, int(1e4))
206 self.distance_prior_array = np.array(
207 [self.priors['luminosity_distance'].prob(distance)
208 for distance in self._distance_array])
209 self._ref_dist = self.priors['luminosity_distance'].rescale(0.5)
210 self._setup_distance_marginalization(
211 distance_marginalization_lookup_table)
212 for key in ['redshift', 'comoving_distance']:
213 if key in priors:
214 del priors[key]
215 priors['luminosity_distance'] = float(self._ref_dist)
216 self._marginalized_parameters.append('luminosity_distance')
218 if self.calibration_marginalization:
219 self.number_of_response_curves = number_of_response_curves
220 self.starting_index = starting_index
221 self._setup_calibration_marginalization(calibration_lookup_table, priors)
222 self._marginalized_parameters.append('recalib_index')
224 def __repr__(self):
225 return self.__class__.__name__ + '(interferometers={},\n\twaveform_generator={},\n\ttime_marginalization={}, ' \
226 'distance_marginalization={}, phase_marginalization={}, ' \
227 'calibration_marginalization={}, priors={})' \
228 .format(self.interferometers, self.waveform_generator, self.time_marginalization,
229 self.distance_marginalization, self.phase_marginalization, self.calibration_marginalization,
230 self.priors)
232 def _check_set_duration_and_sampling_frequency_of_waveform_generator(self):
233 """ Check the waveform_generator has the same duration and
234 sampling_frequency as the interferometers. If they are unset, then
235 set them, if they differ, raise an error
236 """
238 attributes = ['duration', 'sampling_frequency', 'start_time']
239 for attribute in attributes:
240 wfg_attr = getattr(self.waveform_generator, attribute)
241 ifo_attr = getattr(self.interferometers, attribute)
242 if wfg_attr is None:
243 logger.debug(
244 "The waveform_generator {} is None. Setting from the "
245 "provided interferometers.".format(attribute))
246 elif wfg_attr != ifo_attr:
247 logger.debug(
248 "The waveform_generator {} is not equal to that of the "
249 "provided interferometers. Overwriting the "
250 "waveform_generator.".format(attribute))
251 setattr(self.waveform_generator, attribute, ifo_attr)
253 def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True):
254 """
255 Compute the snrs
257 Parameters
258 ----------
259 waveform_polarizations: dict
260 A dictionary of waveform polarizations and the corresponding array
261 interferometer: bilby.gw.detector.Interferometer
262 The bilby interferometer object
263 return_array: bool
264 If true, calculate and return internal array objects
265 (d_inner_h_array and optimal_snr_squared_array), otherwise
266 these are returned as None.
268 Returns
269 -------
270 calculated_snrs: _CalculatedSNRs
271 An object containing the SNR quantities and (if return_array=True)
272 the internal array objects.
274 """
275 signal = self._compute_full_waveform(
276 signal_polarizations=waveform_polarizations,
277 interferometer=interferometer,
278 )
279 _mask = interferometer.frequency_mask
281 if 'recalib_index' in self.parameters:
282 signal[_mask] *= self.calibration_draws[interferometer.name][int(self.parameters['recalib_index'])]
284 d_inner_h = interferometer.inner_product(signal=signal)
285 optimal_snr_squared = interferometer.optimal_snr_squared(signal=signal)
286 complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5)
288 d_inner_h_array = None
289 optimal_snr_squared_array = None
291 normalization = 4 / self.waveform_generator.duration
293 if return_array is False:
294 d_inner_h_array = None
295 optimal_snr_squared_array = None
296 elif self.time_marginalization and self.calibration_marginalization:
298 d_inner_h_integrand = np.tile(
299 interferometer.frequency_domain_strain.conjugate() * signal /
300 interferometer.power_spectral_density_array, (self.number_of_response_curves, 1)).T
302 d_inner_h_integrand[_mask] *= self.calibration_draws[interferometer.name].T
304 d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft(
305 d_inner_h_integrand[0:-1], axis=0
306 ).T
308 optimal_snr_squared_integrand = (
309 normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array
310 )
311 optimal_snr_squared_array = np.dot(
312 optimal_snr_squared_integrand[_mask],
313 self.calibration_abs_draws[interferometer.name].T
314 )
316 elif self.time_marginalization and not self.calibration_marginalization:
317 d_inner_h_array = normalization * np.fft.fft(
318 signal[0:-1]
319 * interferometer.frequency_domain_strain.conjugate()[0:-1]
320 / interferometer.power_spectral_density_array[0:-1]
321 )
323 elif self.calibration_marginalization and ('recalib_index' not in self.parameters):
324 d_inner_h_integrand = (
325 normalization *
326 interferometer.frequency_domain_strain.conjugate() * signal
327 / interferometer.power_spectral_density_array
328 )
329 d_inner_h_array = np.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T)
331 optimal_snr_squared_integrand = (
332 normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array
333 )
334 optimal_snr_squared_array = np.dot(
335 optimal_snr_squared_integrand[_mask],
336 self.calibration_abs_draws[interferometer.name].T
337 )
339 return self._CalculatedSNRs(
340 d_inner_h=d_inner_h,
341 optimal_snr_squared=optimal_snr_squared.real,
342 complex_matched_filter_snr=complex_matched_filter_snr,
343 d_inner_h_array=d_inner_h_array,
344 optimal_snr_squared_array=optimal_snr_squared_array,
345 )
347 def _check_marginalized_prior_is_set(self, key):
348 if key in self.priors and self.priors[key].is_fixed:
349 raise ValueError(
350 "Cannot use marginalized likelihood for {}: prior is fixed".format(key)
351 )
352 if key not in self.priors or not isinstance(
353 self.priors[key], Prior):
354 logger.warning(
355 'Prior not provided for {}, using the BBH default.'.format(key))
356 if key == 'geocent_time':
357 self.priors[key] = Uniform(
358 self.interferometers.start_time,
359 self.interferometers.start_time + self.interferometers.duration)
360 elif key == 'luminosity_distance':
361 for key in ['redshift', 'comoving_distance']:
362 if key in self.priors:
363 if not isinstance(self.priors[key], Cosmological):
364 raise TypeError(
365 "To marginalize over {}, the prior must be specified as a "
366 "subclass of bilby.gw.prior.Cosmological.".format(key)
367 )
368 self.priors['luminosity_distance'] = self.priors[key].get_corresponding_prior(
369 'luminosity_distance'
370 )
371 del self.priors[key]
372 else:
373 self.priors[key] = BBHPriorDict()[key]
375 @property
376 def priors(self):
377 return self._prior
379 @priors.setter
380 def priors(self, priors):
381 if priors is not None:
382 self._prior = priors.copy()
383 elif any([self.time_marginalization, self.phase_marginalization,
384 self.distance_marginalization]):
385 raise ValueError("You can't use a marginalized likelihood without specifying a priors")
386 else:
387 self._prior = None
389 def _calculate_noise_log_likelihood(self):
390 log_l = 0
391 for interferometer in self.interferometers:
392 mask = interferometer.frequency_mask
393 log_l -= noise_weighted_inner_product(
394 interferometer.frequency_domain_strain[mask],
395 interferometer.frequency_domain_strain[mask],
396 interferometer.power_spectral_density_array[mask],
397 self.waveform_generator.duration) / 2
398 return float(np.real(log_l))
400 def noise_log_likelihood(self):
401 # only compute likelihood if called for the 1st time
402 if self._noise_log_likelihood_value is None:
403 self._noise_log_likelihood_value = self._calculate_noise_log_likelihood()
404 return self._noise_log_likelihood_value
406 def log_likelihood_ratio(self):
407 waveform_polarizations = \
408 self.waveform_generator.frequency_domain_strain(self.parameters)
409 if waveform_polarizations is None:
410 return np.nan_to_num(-np.inf)
412 if self.time_marginalization and self.jitter_time:
413 self.parameters['geocent_time'] += self.parameters['time_jitter']
415 self.parameters.update(self.get_sky_frame_parameters())
417 total_snrs = self._CalculatedSNRs()
419 for interferometer in self.interferometers:
420 per_detector_snr = self.calculate_snrs(
421 waveform_polarizations=waveform_polarizations,
422 interferometer=interferometer)
424 total_snrs += per_detector_snr
426 log_l = self.compute_log_likelihood_from_snrs(total_snrs)
428 if self.time_marginalization and self.jitter_time:
429 self.parameters['geocent_time'] -= self.parameters['time_jitter']
431 return float(log_l.real)
433 def compute_log_likelihood_from_snrs(self, total_snrs):
435 if self.calibration_marginalization:
436 log_l = self.calibration_marginalized_likelihood(
437 d_inner_h_calibration_array=total_snrs.d_inner_h_array,
438 h_inner_h=total_snrs.optimal_snr_squared_array)
440 elif self.time_marginalization:
441 log_l = self.time_marginalized_likelihood(
442 d_inner_h_tc_array=total_snrs.d_inner_h_array,
443 h_inner_h=total_snrs.optimal_snr_squared)
445 elif self.distance_marginalization:
446 log_l = self.distance_marginalized_likelihood(
447 d_inner_h=total_snrs.d_inner_h, h_inner_h=total_snrs.optimal_snr_squared)
449 elif self.phase_marginalization:
450 log_l = self.phase_marginalized_likelihood(
451 d_inner_h=total_snrs.d_inner_h, h_inner_h=total_snrs.optimal_snr_squared)
453 else:
454 log_l = np.real(total_snrs.d_inner_h) - total_snrs.optimal_snr_squared / 2
456 return log_l
458 def compute_per_detector_log_likelihood(self):
459 waveform_polarizations = \
460 self.waveform_generator.frequency_domain_strain(self.parameters)
462 if self.time_marginalization and self.jitter_time:
463 self.parameters['geocent_time'] += self.parameters['time_jitter']
465 self.parameters.update(self.get_sky_frame_parameters())
467 for interferometer in self.interferometers:
468 per_detector_snr = self.calculate_snrs(
469 waveform_polarizations=waveform_polarizations,
470 interferometer=interferometer)
472 self.parameters['{}_log_likelihood'.format(interferometer.name)] = \
473 self.compute_log_likelihood_from_snrs(per_detector_snr)
475 if self.time_marginalization and self.jitter_time:
476 self.parameters['geocent_time'] -= self.parameters['time_jitter']
478 return self.parameters.copy()
480 def generate_posterior_sample_from_marginalized_likelihood(self):
481 """
482 Reconstruct the distance posterior from a run which used a likelihood
483 which explicitly marginalised over time/distance/phase.
485 See Eq. (C29-C32) of https://arxiv.org/abs/1809.02293
487 Returns
488 =======
489 sample: dict
490 Returns the parameters with new samples.
492 Notes
493 =====
494 This involves a deepcopy of the signal to avoid issues with waveform
495 caching, as the signal is overwritten in place.
496 """
497 if len(self._marginalized_parameters) > 0:
498 signal_polarizations = copy.deepcopy(
499 self.waveform_generator.frequency_domain_strain(
500 self.parameters))
501 else:
502 return self.parameters
504 if self.calibration_marginalization:
505 new_calibration = self.generate_calibration_sample_from_marginalized_likelihood(
506 signal_polarizations=signal_polarizations)
507 self.parameters['recalib_index'] = new_calibration
508 if self.time_marginalization:
509 new_time = self.generate_time_sample_from_marginalized_likelihood(
510 signal_polarizations=signal_polarizations)
511 self.parameters['geocent_time'] = new_time
512 if self.distance_marginalization:
513 new_distance = self.generate_distance_sample_from_marginalized_likelihood(
514 signal_polarizations=signal_polarizations)
515 self.parameters['luminosity_distance'] = new_distance
516 if self.phase_marginalization:
517 new_phase = self.generate_phase_sample_from_marginalized_likelihood(
518 signal_polarizations=signal_polarizations)
519 self.parameters['phase'] = new_phase
520 return self.parameters.copy()
522 def generate_calibration_sample_from_marginalized_likelihood(
523 self, signal_polarizations=None):
524 """
525 Generate a single sample from the posterior distribution for the set of calibration response curves when
526 explicitly marginalizing over the calibration uncertainty.
528 Parameters
529 ----------
530 signal_polarizations: dict, optional
531 Polarizations modes of the template.
533 Returns
534 -------
535 new_calibration: dict
536 Sample set from the calibration posterior
537 """
538 from ...core.utils.random import rng
540 if 'recalib_index' in self.parameters:
541 self.parameters.pop('recalib_index')
542 self.parameters.update(self.get_sky_frame_parameters())
543 if signal_polarizations is None:
544 signal_polarizations = \
545 self.waveform_generator.frequency_domain_strain(self.parameters)
547 log_like = self.get_calibration_log_likelihoods(signal_polarizations=signal_polarizations)
549 calibration_post = np.exp(log_like - max(log_like))
550 calibration_post /= np.sum(calibration_post)
552 new_calibration = rng.choice(self.number_of_response_curves, p=calibration_post)
554 return new_calibration
556 def generate_time_sample_from_marginalized_likelihood(
557 self, signal_polarizations=None):
558 """
559 Generate a single sample from the posterior distribution for coalescence
560 time when using a likelihood which explicitly marginalises over time.
562 In order to resolve the posterior we artificially upsample to 16kHz.
564 See Eq. (C29-C32) of https://arxiv.org/abs/1809.02293
566 Parameters
567 ==========
568 signal_polarizations: dict, optional
569 Polarizations modes of the template.
571 Returns
572 =======
573 new_time: float
574 Sample from the time posterior.
575 """
576 self.parameters.update(self.get_sky_frame_parameters())
577 if self.jitter_time:
578 self.parameters['geocent_time'] += self.parameters['time_jitter']
579 if signal_polarizations is None:
580 signal_polarizations = \
581 self.waveform_generator.frequency_domain_strain(self.parameters)
583 times = create_time_series(
584 sampling_frequency=16384,
585 starting_time=self.parameters['geocent_time'] - self.waveform_generator.start_time,
586 duration=self.waveform_generator.duration)
587 times = times % self.waveform_generator.duration
588 times += self.waveform_generator.start_time
590 prior = self.priors["geocent_time"]
591 in_prior = (times >= prior.minimum) & (times < prior.maximum)
592 times = times[in_prior]
594 n_time_steps = int(self.waveform_generator.duration * 16384)
595 d_inner_h = np.zeros(len(times), dtype=complex)
596 psd = np.ones(n_time_steps)
597 signal_long = np.zeros(n_time_steps, dtype=complex)
598 data = np.zeros(n_time_steps, dtype=complex)
599 h_inner_h = np.zeros(1)
600 for ifo in self.interferometers:
601 ifo_length = len(ifo.frequency_domain_strain)
602 mask = ifo.frequency_mask
603 signal = self._compute_full_waveform(
604 signal_polarizations=signal_polarizations,
605 interferometer=ifo,
606 )
607 signal_long[:ifo_length] = signal
608 data[:ifo_length] = np.conj(ifo.frequency_domain_strain)
609 psd[:ifo_length][mask] = ifo.power_spectral_density_array[mask]
610 d_inner_h += np.fft.fft(signal_long * data / psd)[in_prior]
611 h_inner_h += ifo.optimal_snr_squared(signal=signal).real
613 if self.distance_marginalization:
614 time_log_like = self.distance_marginalized_likelihood(
615 d_inner_h, h_inner_h)
616 elif self.phase_marginalization:
617 time_log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2
618 else:
619 time_log_like = (d_inner_h.real - h_inner_h.real / 2)
621 time_prior_array = self.priors['geocent_time'].prob(times)
622 time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array
624 keep = (time_post > max(time_post) / 1000)
625 if sum(keep) < 3:
626 keep[1:-1] = keep[1:-1] | keep[2:] | keep[:-2]
627 time_post = time_post[keep]
628 times = times[keep]
630 new_time = Interped(times, time_post).sample()
631 return new_time
633 def generate_distance_sample_from_marginalized_likelihood(
634 self, signal_polarizations=None):
635 """
636 Generate a single sample from the posterior distribution for luminosity
637 distance when using a likelihood which explicitly marginalises over
638 distance.
640 See Eq. (C29-C32) of https://arxiv.org/abs/1809.02293
642 Parameters
643 ==========
644 signal_polarizations: dict, optional
645 Polarizations modes of the template.
646 Note: These are rescaled in place after the distance sample is
647 generated to allow further parameter reconstruction to occur.
649 Returns
650 =======
651 new_distance: float
652 Sample from the distance posterior.
653 """
654 self.parameters.update(self.get_sky_frame_parameters())
655 if signal_polarizations is None:
656 signal_polarizations = \
657 self.waveform_generator.frequency_domain_strain(self.parameters)
659 d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations)
661 d_inner_h_dist = (
662 d_inner_h * self.parameters['luminosity_distance'] / self._distance_array
663 )
665 h_inner_h_dist = (
666 h_inner_h * self.parameters['luminosity_distance']**2 / self._distance_array**2
667 )
669 if self.phase_marginalization:
670 distance_log_like = ln_i0(abs(d_inner_h_dist)) - h_inner_h_dist.real / 2
671 else:
672 distance_log_like = (d_inner_h_dist.real - h_inner_h_dist.real / 2)
674 distance_post = (np.exp(distance_log_like - max(distance_log_like)) *
675 self.distance_prior_array)
677 new_distance = Interped(
678 self._distance_array, distance_post).sample()
680 self._rescale_signal(signal_polarizations, new_distance)
681 return new_distance
683 def _calculate_inner_products(self, signal_polarizations):
684 d_inner_h = 0
685 h_inner_h = 0
686 for interferometer in self.interferometers:
687 per_detector_snr = self.calculate_snrs(
688 signal_polarizations, interferometer)
690 d_inner_h += per_detector_snr.d_inner_h
691 h_inner_h += per_detector_snr.optimal_snr_squared
692 return d_inner_h, h_inner_h
694 def _compute_full_waveform(self, signal_polarizations, interferometer):
695 """
696 Project the waveform polarizations against the interferometer
697 response. This is useful for likelihood classes that don't
698 use the full frequency array, e.g., the relative binning
699 likelihood.
701 Parameters
702 ==========
703 signal_polarizations: dict
704 Dictionary containing the waveform evaluated at
705 :code:`interferometer.frequency_array`.
706 interferometer: bilby.gw.detector.Interferometer
707 Interferometer to compute the response with respect to.
708 """
709 return interferometer.get_detector_response(signal_polarizations, self.parameters)
711 def generate_phase_sample_from_marginalized_likelihood(
712 self, signal_polarizations=None):
713 r"""
714 Generate a single sample from the posterior distribution for phase when
715 using a likelihood which explicitly marginalises over phase.
717 See Eq. (C29-C32) of https://arxiv.org/abs/1809.02293
719 Parameters
720 ==========
721 signal_polarizations: dict, optional
722 Polarizations modes of the template.
724 Returns
725 =======
726 new_phase: float
727 Sample from the phase posterior.
729 Notes
730 =====
731 This is only valid when assumes that mu(phi) \propto exp(-2i phi).
732 """
733 self.parameters.update(self.get_sky_frame_parameters())
734 if signal_polarizations is None:
735 signal_polarizations = \
736 self.waveform_generator.frequency_domain_strain(self.parameters)
737 d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations)
739 phases = np.linspace(0, 2 * np.pi, 101)
740 phasor = np.exp(-2j * phases)
741 phase_log_post = d_inner_h * phasor - h_inner_h / 2
742 phase_post = np.exp(phase_log_post.real - max(phase_log_post.real))
743 new_phase = Interped(phases, phase_post).sample()
744 return new_phase
746 def distance_marginalized_likelihood(self, d_inner_h, h_inner_h):
747 d_inner_h_ref, h_inner_h_ref = self._setup_rho(
748 d_inner_h, h_inner_h)
749 if self.phase_marginalization:
750 d_inner_h_ref = np.abs(d_inner_h_ref)
751 else:
752 d_inner_h_ref = np.real(d_inner_h_ref)
754 return self._interp_dist_margd_loglikelihood(
755 d_inner_h_ref, h_inner_h_ref, grid=False)
757 def phase_marginalized_likelihood(self, d_inner_h, h_inner_h):
758 d_inner_h = ln_i0(abs(d_inner_h))
760 if self.calibration_marginalization and self.time_marginalization:
761 return d_inner_h - h_inner_h[:, np.newaxis] / 2
762 else:
763 return d_inner_h - h_inner_h / 2
765 def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h):
766 times = self._times
767 if self.jitter_time:
768 times = self._times + self.parameters['time_jitter']
770 _time_prior = self.priors['geocent_time']
771 time_mask = (times >= _time_prior.minimum) & (times <= _time_prior.maximum)
772 times = times[time_mask]
773 time_prior_array = self.priors['geocent_time'].prob(times) * self._delta_tc
774 if self.calibration_marginalization:
775 d_inner_h_tc_array = d_inner_h_tc_array[:, time_mask]
776 else:
777 d_inner_h_tc_array = d_inner_h_tc_array[time_mask]
779 if self.distance_marginalization:
780 log_l_tc_array = self.distance_marginalized_likelihood(
781 d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h)
782 elif self.phase_marginalization:
783 log_l_tc_array = self.phase_marginalized_likelihood(
784 d_inner_h=d_inner_h_tc_array,
785 h_inner_h=h_inner_h)
786 elif self.calibration_marginalization:
787 log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h[:, np.newaxis] / 2
788 else:
789 log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h / 2
790 return logsumexp(log_l_tc_array, b=time_prior_array, axis=-1)
792 def get_calibration_log_likelihoods(self, signal_polarizations=None):
793 self.parameters.update(self.get_sky_frame_parameters())
794 if signal_polarizations is None:
795 signal_polarizations = \
796 self.waveform_generator.frequency_domain_strain(self.parameters)
798 total_snrs = self._CalculatedSNRs()
800 for interferometer in self.interferometers:
801 per_detector_snr = self.calculate_snrs(
802 waveform_polarizations=signal_polarizations,
803 interferometer=interferometer)
805 total_snrs += per_detector_snr
807 if self.time_marginalization:
808 log_l_cal_array = self.time_marginalized_likelihood(
809 d_inner_h_tc_array=total_snrs.d_inner_h_array,
810 h_inner_h=total_snrs.optimal_snr_squared_array,
811 )
812 elif self.distance_marginalization:
813 log_l_cal_array = self.distance_marginalized_likelihood(
814 d_inner_h=total_snrs.d_inner_h_array,
815 h_inner_h=total_snrs.optimal_snr_squared_array)
816 elif self.phase_marginalization:
817 log_l_cal_array = self.phase_marginalized_likelihood(
818 d_inner_h=total_snrs.d_inner_h_array,
819 h_inner_h=total_snrs.optimal_snr_squared_array)
820 else:
821 log_l_cal_array = \
822 np.real(total_snrs.d_inner_h_array - total_snrs.optimal_snr_squared_array / 2)
824 return log_l_cal_array
def calibration_marginalized_likelihood(self, d_inner_h_calibration_array, h_inner_h):
    """
    Marginalize the log-likelihood over calibration response curves.

    Per-draw log-likelihood values come from the first enabled
    marginalization (time, then distance, then phase). If none is enabled,
    they are ``Re(d_inner_h - h_inner_h / 2)``. The values are then combined
    as ``logsumexp(values) - log(number_of_response_curves)``, which is the
    log of their mean when there is one value per response curve.

    Parameters
    ==========
    d_inner_h_calibration_array: array_like
        ``<d|h>`` values for the calibration draws (layout for the
        time-marginalized case not visible here; TODO confirm).
    h_inner_h: array_like or float
        Optimal SNR squared to pair with ``d_inner_h_calibration_array``.

    Returns
    =======
    float: the calibration-marginalized log-likelihood.
    """
    if self.time_marginalization:
        per_draw = self.time_marginalized_likelihood(
            d_inner_h_tc_array=d_inner_h_calibration_array,
            h_inner_h=h_inner_h,
        )
    elif self.distance_marginalization:
        per_draw = self.distance_marginalized_likelihood(
            d_inner_h=d_inner_h_calibration_array, h_inner_h=h_inner_h)
    elif self.phase_marginalization:
        per_draw = self.phase_marginalized_likelihood(
            d_inner_h=d_inner_h_calibration_array, h_inner_h=h_inner_h)
    else:
        per_draw = np.real(d_inner_h_calibration_array - h_inner_h / 2)

    log_total = logsumexp(per_draw)
    return log_total - np.log(self.number_of_response_curves)
844 def _setup_rho(self, d_inner_h, optimal_snr_squared):
845 optimal_snr_squared_ref = (optimal_snr_squared.real *
846 self.parameters['luminosity_distance'] ** 2 /
847 self._ref_dist ** 2.)
848 d_inner_h_ref = (d_inner_h * self.parameters['luminosity_distance'] /
849 self._ref_dist)
850 return d_inner_h_ref, optimal_snr_squared_ref
def log_likelihood(self):
    """Return the full log-likelihood: the log-likelihood ratio plus the noise log-likelihood."""
    log_ratio = self.log_likelihood_ratio()
    log_noise = self.noise_log_likelihood()
    return log_ratio + log_noise
855 @property
856 def _delta_distance(self):
857 return self._distance_array[1] - self._distance_array[0]
859 @property
860 def _dist_multiplier(self):
861 ''' Maximum value of ref_dist/dist_array '''
862 return self._ref_dist / self._distance_array[0]
864 @property
865 def _optimal_snr_squared_ref_array(self):
866 """ Optimal filter snr at fiducial distance of ref_dist Mpc """
867 return np.logspace(-5, 10, self._dist_margd_loglikelihood_array.shape[0])
869 @property
870 def _d_inner_h_ref_array(self):
871 """ Matched filter snr at fiducial distance of ref_dist Mpc """
872 if self.phase_marginalization:
873 return np.logspace(-5, 10, self._dist_margd_loglikelihood_array.shape[1])
874 else:
875 n_negative = self._dist_margd_loglikelihood_array.shape[1] // 2
876 n_positive = self._dist_margd_loglikelihood_array.shape[1] - n_negative
877 return np.hstack((
878 -np.logspace(3, -3, n_negative), np.logspace(-3, 10, n_positive)
879 ))
def _setup_distance_marginalization(self, lookup_table=None):
    """
    Prepare the interpolant of the distance-marginalized log-likelihood.

    Parameters
    ==========
    lookup_table: str, dict, optional
        A filename (None means the default cache filename) to load the
        table from, or an already loaded dict. If no dict is obtained, or the
        dict fails :code:`_test_cached_lookup_table`, the table is rebuilt
        with :code:`_create_lookup_table`.
    """
    if isinstance(lookup_table, str) or lookup_table is None:
        self.cached_lookup_table_filename = lookup_table
        lookup_table = self.load_lookup_table(
            self.cached_lookup_table_filename)
    if isinstance(lookup_table, dict):
        # _test_cached_lookup_table returns a (match, failed_key) tuple. A
        # non-empty tuple is always truthy, so it must be unpacked before
        # branching, otherwise mismatching tables are silently accepted.
        match, failure = self._test_cached_lookup_table(lookup_table)
        if match:
            self._dist_margd_loglikelihood_array = lookup_table[
                'lookup_table']
        else:
            logger.info('Provided distance marginalisation lookup table does '
                        'not match for {}.'.format(failure))
            self._create_lookup_table()
    else:
        self._create_lookup_table()
    self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline(
        self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array,
        self._dist_margd_loglikelihood_array.T, fill_value=-np.inf)
@property
def cached_lookup_table_filename(self):
    """Path of the ``.npz`` file that caches the distance-marginalisation lookup table.

    If no filename has been set, the default
    ``'.distance_marginalization_lookup.npz'`` is stored and returned.
    """
    if self._lookup_table_filename is None:
        self._lookup_table_filename = '.distance_marginalization_lookup.npz'
    return self._lookup_table_filename

@cached_lookup_table_filename.setter
def cached_lookup_table_filename(self, filename):
    # Strings get a '.npz' suffix if they lack one. Any other value (e.g.
    # None) is stored unchanged, which makes the getter use the default.
    if isinstance(filename, str) and not filename.endswith('.npz'):
        filename = filename + '.npz'
    self._lookup_table_filename = filename
def load_lookup_table(self, filename):
    """
    Load a cached distance-marginalisation lookup table.

    Parameters
    ==========
    filename: str
        Path to an ``.npz`` file, e.g. one written by :code:`cache_lookup_table`.

    Returns
    =======
    dict or None
        The file contents (array name -> array) if the file exists, can be
        read, and passes :code:`_test_cached_lookup_table`. Otherwise None.
        This method never builds a table itself; a None return leaves that
        to the caller (:code:`_setup_distance_marginalization` builds one
        whenever it does not receive a dict).
    """
    if os.path.exists(filename):
        try:
            # np.load keeps the .npz file open; the context manager closes it
            # once every array has been read into the dict.
            with np.load(filename) as npz_file:
                loaded_file = dict(npz_file)
        except AttributeError as e:
            logger.warning(e)
            # Do not build the table here: the caller builds it on None, and
            # doing it here as well would build the expensive table twice.
            return None
        match, failure = self._test_cached_lookup_table(loaded_file)
        if match:
            logger.info('Loaded distance marginalisation lookup table from '
                        '{}.'.format(filename))
            return loaded_file
        else:
            logger.info('Loaded distance marginalisation lookup table does '
                        'not match for {}.'.format(failure))
    elif isinstance(filename, str):
        logger.info('Distance marginalisation file {} does not '
                    'exist'.format(filename))
    return None
def cache_lookup_table(self):
    """
    Save the distance-marginalisation lookup table with ``np.savez``.

    The file also stores the distance grid, the distance prior values, the
    reference distance and the phase-marginalization flag, so that
    :code:`_test_cached_lookup_table` can later check that the table
    matches the current setup. The file is written to
    :code:`cached_lookup_table_filename`.
    """
    target = self.cached_lookup_table_filename
    contents = dict(
        distance_array=self._distance_array,
        prior_array=self.distance_prior_array,
        lookup_table=self._dist_margd_loglikelihood_array,
        reference_distance=self._ref_dist,
        phase_marginalization=self.phase_marginalization,
    )
    np.savez(target, **contents)
941 def _test_cached_lookup_table(self, loaded_file):
942 pairs = dict(
943 distance_array=self._distance_array,
944 prior_array=self.distance_prior_array,
945 reference_distance=self._ref_dist,
946 phase_marginalization=self.phase_marginalization)
947 for key in pairs:
948 if key not in loaded_file:
949 return False, key
950 elif not np.allclose(np.atleast_1d(loaded_file[key]),
951 np.atleast_1d(pairs[key]),
952 rtol=1e-15):
953 return False, key
954 return True, None
def _create_lookup_table(self):
    """ Make the lookup table

    Entry ``[ii, jj]`` of ``self._dist_margd_loglikelihood_array`` is

        logsumexp(d_jj(D) - h_ii(D) / 2, b=prior(D) * delta_D)
        - log(sum(prior(D) * delta_D))

    where ``D`` runs over ``self._distance_array``. ``h_ii`` is the ii-th
    entry of :code:`_optimal_snr_squared_ref_array` scaled by
    ``(ref_dist / D) ** 2``, and ``d_jj`` is the jj-th entry of
    :code:`_d_inner_h_ref_array` scaled by ``ref_dist / D``. With phase
    marginalization, ``d_jj`` is replaced by ``ln_i0(|d_jj|)``. The finished
    table is written to disk with :code:`cache_lookup_table`.
    """
    from tqdm.auto import tqdm
    logger.info('Building lookup table for distance marginalisation.')

    # The table shape must be set first: the reference grids below take
    # their lengths from it.
    self._dist_margd_loglikelihood_array = np.zeros((400, 800))
    scaling = self._ref_dist / self._distance_array
    d_inner_h_array_full = np.outer(self._d_inner_h_ref_array, scaling)
    h_inner_h_array_full = np.outer(self._optimal_snr_squared_ref_array, scaling ** 2)
    if self.phase_marginalization:
        d_inner_h_array_full = ln_i0(abs(d_inner_h_array_full))
    prior_term = self.distance_prior_array * self._delta_distance
    for ii, optimal_snr_squared_array in tqdm(
        enumerate(h_inner_h_array_full), total=len(h_inner_h_array_full)
    ):
        for jj, d_inner_h_array in enumerate(d_inner_h_array_full):
            self._dist_margd_loglikelihood_array[ii, jj] = logsumexp(
                d_inner_h_array - optimal_snr_squared_array / 2,
                b=prior_term
            )
    # Log of the total prior weight (logsumexp of zeros with the same
    # weights). zeros_like avoids building the zeros as 0 / distance.
    log_norm = logsumexp(
        np.zeros_like(self._distance_array, dtype=float), b=prior_term
    )
    self._dist_margd_loglikelihood_array -= log_norm
    self.cache_lookup_table()
def _setup_phase_marginalization(self, min_bound=-5, max_bound=10):
    """
    Deprecated: only emits a deprecation warning through ``logger``.

    ``min_bound`` and ``max_bound`` are accepted for backward compatibility
    but are not used.
    """
    deprecation_message = (
        "The _setup_phase_marginalization method is deprecated and will be removed, "
        "please update the implementation of phase marginalization "
        "to use bilby.gw.utils.ln_i0"
    )
    logger.warning(deprecation_message)
989 def _setup_time_marginalization(self):
990 self._delta_tc = 2 / self.waveform_generator.sampling_frequency
991 self._times = \
992 self.interferometers.start_time + np.linspace(
993 0, self.interferometers.duration,
994 int(self.interferometers.duration / 2 *
995 self.waveform_generator.sampling_frequency + 1))[1:]
996 self.time_prior_array = \
997 self.priors['geocent_time'].prob(self._times) * self._delta_tc
def _setup_calibration_marginalization(self, calibration_lookup_table, priors=None):
    """
    Obtain the calibration response-curve draws used for calibration marginalization.

    Sets ``self.calibration_draws`` and ``self.calibration_parameter_draws``
    from :code:`calibration.build_calibration_lookup`, and
    ``self.calibration_abs_draws`` to the squared magnitude of each entry of
    ``self.calibration_draws``.

    Parameters
    ==========
    calibration_lookup_table:
        Forwarded to ``build_calibration_lookup`` as ``lookup_files``.
    priors: dict-like, optional
        Modified in place. Every key that appears both in ``priors`` and in
        any non-None entry of ``self.calibration_parameter_draws`` is
        replaced by ``DeltaFunction(0.0)`` (presumably so those calibration
        parameters are not also sampled). When None, no priors are modified.
    """
    self.calibration_draws, self.calibration_parameter_draws = calibration.build_calibration_lookup(
        interferometers=self.interferometers,
        lookup_files=calibration_lookup_table,
        priors=priors,
        number_of_response_curves=self.number_of_response_curves,
        starting_index=self.starting_index,
    )
    # priors defaults to None; previously this loop then crashed on
    # priors.keys() whenever any parameter draws were returned.
    if priors is not None:
        for parameters in self.calibration_parameter_draws.values():
            if parameters is None:
                continue
            for key in set(parameters.keys()).intersection(priors.keys()):
                priors[key] = DeltaFunction(0.0)
    self.calibration_abs_draws = {
        name: np.abs(self.calibration_draws[name]) ** 2
        for name in self.calibration_draws
    }
@property
def interferometers(self):
    """The detectors used by this likelihood, stored as an ``InterferometerList``."""
    return self._interferometers

@interferometers.setter
def interferometers(self, detectors):
    # Whatever is assigned is wrapped in an InterferometerList.
    self._interferometers = InterferometerList(detectors)
1023 def _rescale_signal(self, signal, new_distance):
1024 for mode in signal:
1025 signal[mode] *= self._ref_dist / new_distance
1027 @property
1028 def reference_frame(self):
1029 return self._reference_frame
1031 @property
1032 def _reference_frame_str(self):
1033 if isinstance(self.reference_frame, str):
1034 return self.reference_frame
1035 else:
1036 return "".join([ifo.name for ifo in self.reference_frame])
1038 @reference_frame.setter
1039 def reference_frame(self, frame):
1040 if frame == "sky":
1041 self._reference_frame = frame
1042 elif isinstance(frame, InterferometerList):
1043 self._reference_frame = frame[:2]
1044 elif isinstance(frame, list):
1045 self._reference_frame = InterferometerList(frame[:2])
1046 elif isinstance(frame, str):
1047 self._reference_frame = InterferometerList([frame[:2], frame[2:4]])
1048 else:
1049 raise ValueError("Unable to parse reference frame {}".format(frame))
def get_sky_frame_parameters(self, parameters=None):
    """
    Generate ra, dec, and geocenter time for :code:`parameters`

    This method will attempt to convert from the reference time and sky
    parameters, but if they are not present it will fall back to ra and dec.
    If ``{time_reference}_time`` is missing but ``geocent_time`` is present,
    ``geocent_time`` is used in its place, including for the
    zenith/azimuth to ra/dec conversion.

    Parameters
    ==========
    parameters: dict, optional
        The parameters to be converted.
        If not specified :code:`self.parameters` will be used.

    Returns
    =======
    dict: dictionary containing ra, dec, and geocent_time
    """
    if parameters is None:
        parameters = self.parameters
    reference_time = parameters.get(f'{self.time_reference}_time', None)
    time = reference_time
    if time is None and "geocent_time" in parameters:
        logger.warning(
            f"Cannot find {self.time_reference}_time in parameters. "
            "Falling back to geocent time"
        )
        # Actually fall back as announced. Previously ``time`` stayed None
        # and was handed to the zenith/azimuth conversion below.
        time = parameters["geocent_time"]
    if not self.reference_frame == "sky":
        try:
            ra, dec = zenith_azimuth_to_ra_dec(
                parameters['zenith'], parameters['azimuth'],
                time, self.reference_frame)
        except KeyError:
            if "ra" in parameters and "dec" in parameters:
                ra = parameters["ra"]
                dec = parameters["dec"]
                logger.warning(
                    "Cannot convert from zenith/azimuth to ra/dec falling "
                    "back to provided ra/dec"
                )
            else:
                raise
    else:
        ra = parameters["ra"]
        dec = parameters["dec"]
    # Only shift from the detector time to the geocenter when a detector
    # reference time was actually provided; after a fallback, ``time``
    # already is the geocenter time.
    if "geocent" not in self.time_reference and reference_time is not None:
        geocent_time = time - self.reference_ifo.time_delay_from_geocenter(
            ra=ra, dec=dec, time=time
        )
    else:
        geocent_time = parameters["geocent_time"]
    return dict(ra=ra, dec=dec, geocent_time=geocent_time)
@property
def lal_version(self):
    """
    Describe the installed lal version.

    Logs the version and git information. Returns "N/A" if lal cannot be
    imported or lacks the expected attributes.
    """
    try:
        from lal import git_version, __version__
        version = str(__version__)
        logger.info("Using lal version {}".format(version))
        git_details = str(git_version.verbose_msg).replace("\n", ";")
        logger.info("Using lal git version {}".format(git_details))
    except (ImportError, AttributeError):
        return "N/A"
    return "lal_version={}, lal_git_version={}".format(version, git_details)
@property
def lalsimulation_version(self):
    """
    Describe the installed lalsimulation version.

    Logs the version and git information. Returns "N/A" if lalsimulation
    cannot be imported or lacks the expected attributes.
    """
    try:
        from lalsimulation import git_version, __version__
        version = str(__version__)
        logger.info("Using lalsimulation version {}".format(version))
        git_details = str(git_version.verbose_msg).replace("\n", ";")
        logger.info("Using lalsimulation git version {}".format(git_details))
    except (ImportError, AttributeError):
        return "N/A"
    return "lalsimulation_version={}, lalsimulation_git_version={}".format(version, git_details)
@property
def meta_data(self):
    """
    Summary of the likelihood configuration.

    Returns
    =======
    dict: the interferometer meta data, the four marginalization flags, the
    waveform generator's class, arguments, source models, parameter
    conversion and sampling settings, the time reference, the reference
    frame as a string, and the lal / lalsimulation version strings.
    """
    generator = self.waveform_generator
    return {
        "interferometers": self.interferometers.meta_data,
        "time_marginalization": self.time_marginalization,
        "phase_marginalization": self.phase_marginalization,
        "distance_marginalization": self.distance_marginalization,
        "calibration_marginalization": self.calibration_marginalization,
        "waveform_generator_class": generator.__class__,
        "waveform_arguments": generator.waveform_arguments,
        "frequency_domain_source_model": generator.frequency_domain_source_model,
        "time_domain_source_model": generator.time_domain_source_model,
        "parameter_conversion": generator.parameter_conversion,
        "sampling_frequency": generator.sampling_frequency,
        "duration": generator.duration,
        "start_time": generator.start_time,
        "time_reference": self.time_reference,
        "reference_frame": self._reference_frame_str,
        "lal_version": self.lal_version,
        "lalsimulation_version": self.lalsimulation_version,
    }