Coverage for bilby/gw/likelihood/relative.py: 98%
209 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import numpy as np
2from scipy.optimize import differential_evolution
4from .base import GravitationalWaveTransient
5from ...core.utils import logger
6from ...core.prior.base import Constraint
7from ...core.prior import DeltaFunction
8from ..utils import noise_weighted_inner_product
class RelativeBinningGravitationalWaveTransient(GravitationalWaveTransient):
    """A gravitational-wave transient likelihood object which uses the relative
    binning procedure to calculate a fast likelihood. See Zackay et al.
    arXiv1806.08792

    Parameters
    ----------
    interferometers: list, bilby.gw.detector.InterferometerList
        A list of `bilby.detector.Interferometer` instances - contains the
        detector data and power spectral densities
    waveform_generator: `bilby.waveform_generator.WaveformGenerator`
        An object which computes the frequency-domain strain of the signal,
        given some set of parameters
    fiducial_parameters: dict, optional
        A starting guess for initial parameters of the event for finding the
        maximum likelihood (fiducial) waveform. These should be specified in
        the same parameter basis as the one that sampling is carried out in.
        For example, if sampling in `mass_1` and `mass_2`, the fiducial
        parameters should also be provided in `mass_1` and `mass_2`.
        If not given, the fiducial parameters are drawn from :code:`priors`,
        which must then be provided.
    parameter_bounds: dict, optional
        Dictionary of bounds (lists) for the initial parameters when finding
        the initial maximum likelihood (fiducial) waveform. If not given,
        the prior minimum/maximum are used.
    maximization_kwargs: dict, optional
        Extra keyword arguments passed to
        :code:`scipy.optimize.differential_evolution` when updating the
        fiducial parameters.
    update_fiducial_parameters: bool, optional
        If True, maximize the likelihood over every prior parameter that is
        not a :code:`DeltaFunction`, :code:`Constraint` or fixed number, and
        use the result as the fiducial parameters. Requires :code:`priors`.
    distance_marginalization: bool, optional
        If true, marginalize over distance in the likelihood.
        This uses a look up table calculated at run time.
        The distance prior is set to be a delta function at the minimum
        distance allowed in the prior being marginalised over.
    time_marginalization: bool, optional
        If true, marginalize over time in the likelihood.
        This uses a FFT to calculate the likelihood over a regularly spaced
        grid.
        In order to cover the whole space the prior is set to be uniform over
        the spacing of the array of times.
        If using time marginalisation and jitter_time is True a "jitter"
        parameter is added to the prior which modifies the position of the
        grid of times.
    phase_marginalization: bool, optional
        If true, marginalize over phase in the likelihood.
        This is done analytically using a Bessel function.
        The phase prior is set to be a delta function at phase=0.
    priors: dict, optional
        If given, used in the distance and phase marginalization, to draw
        fiducial parameters when none are given, and to choose the
        parameters (and default bounds) for updating the fiducial parameters.
    distance_marginalization_lookup_table: (dict, str), optional
        If a dict, dictionary containing the lookup_table, distance_array,
        (distance) prior_array, and reference_distance used to construct
        the table.
        If a string the name of a file containing these quantities.
        The lookup table is stored after construction in either the
        provided string or a default location:
        '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
    jitter_time: bool, optional
        Whether to introduce a `time_jitter` parameter. This avoids either
        missing the likelihood peak, or introducing biases in the
        reconstructed time posterior due to an insufficient sampling frequency.
        Default is True; using this parameter is strongly encouraged.
    reference_frame: (str, bilby.gw.detector.InterferometerList, list), optional
        Definition of the reference frame for the sky location.
        - "sky": sample in RA/dec, this is the default
        - e.g., "H1L1", ["H1", "L1"], InterferometerList(["H1", "L1"]):
          sample in azimuth and zenith, `azimuth` and `zenith` defined in the
          frame where the z-axis is aligned to the vector connecting H1
          and L1.
    time_reference: str, optional
        Name of the reference for the sampled time parameter.
        - "geocent"/"geocenter": sample in the time at the Earth's center,
          this is the default
        - e.g., "H1": sample in the time of arrival at H1
    chi: float, optional
        Tunable parameter which limits the perturbation of alpha when setting
        up the bin range. See https://arxiv.org/abs/1806.08792.
    epsilon: float, optional
        Tunable parameter which limits the differential phase change in each
        bin when setting up the bin range. See https://arxiv.org/abs/1806.08792.

    Returns
    -------
    Likelihood: `bilby.core.likelihood.Likelihood`
        A likelihood object, able to compute the likelihood of the data given
        some model parameters.

    Raises
    ------
    ValueError
        If :code:`priors` is None while it is needed, i.e. when
        :code:`fiducial_parameters` is None or
        :code:`update_fiducial_parameters` is True.

    Notes
    -----
    The relative binning likelihood does not currently support calibration marginalization.
    """

    def __init__(self, interferometers,
                 waveform_generator,
                 fiducial_parameters=None,
                 parameter_bounds=None,
                 maximization_kwargs=None,
                 update_fiducial_parameters=False,
                 distance_marginalization=False,
                 time_marginalization=False,
                 phase_marginalization=False,
                 priors=None,
                 distance_marginalization_lookup_table=None,
                 jitter_time=True,
                 reference_frame="sky",
                 time_reference="geocenter",
                 chi=1,
                 epsilon=0.5):

        # Fail fast, before any waveform evaluation, when the prior is needed
        # but missing (otherwise this surfaces as an AttributeError on None).
        if priors is None and fiducial_parameters is None:
            raise ValueError(
                "Either fiducial_parameters or priors must be given to set up "
                "the relative binning likelihood."
            )
        if priors is None and update_fiducial_parameters:
            raise ValueError(
                "priors must be given when update_fiducial_parameters is True."
            )

        super().__init__(
            interferometers=interferometers,
            waveform_generator=waveform_generator,
            distance_marginalization=distance_marginalization,
            phase_marginalization=phase_marginalization,
            time_marginalization=time_marginalization,
            priors=priors,
            distance_marginalization_lookup_table=distance_marginalization_lookup_table,
            jitter_time=jitter_time,
            reference_frame=reference_frame,
            time_reference=time_reference)

        if fiducial_parameters is None:
            logger.info("Drawing fiducial parameters from prior.")
            fiducial_parameters = priors.sample()
        self.fiducial_parameters = fiducial_parameters.copy()
        self.fiducial_parameters["fiducial"] = 0
        # Marginalized parameters are pinned to fixed reference values for
        # the fiducial waveform.
        if self.time_marginalization:
            self.fiducial_parameters["geocent_time"] = interferometers.start_time
        if self.distance_marginalization:
            self.fiducial_parameters["luminosity_distance"] = self._ref_dist
        if self.phase_marginalization:
            self.fiducial_parameters["phase"] = 0.0
        self.chi = chi
        self.epsilon = epsilon
        # Power-law exponents of the phase expansion used to place the bins.
        self.gamma = np.array([-5 / 3, -2 / 3, 1, 5 / 3, 7 / 3])
        self.maximum_frequency = waveform_generator.frequency_array[-1]
        self.fiducial_waveform_obtained = False
        self.check_if_bins_are_setup = False
        self.fiducial_polarizations = None
        self.per_detector_fiducial_waveforms = dict()
        self.per_detector_fiducial_waveform_points = dict()
        self.set_fiducial_waveforms(self.fiducial_parameters)
        logger.info("Initial fiducial waveforms set up")
        self.setup_bins()
        self.compute_summary_data()
        logger.info("Summary Data Obtained")

        if update_fiducial_parameters:
            logger.info("Using scipy optimization to find maximum likelihood parameters.")
            self.parameters_to_be_updated = [key for key in priors if not isinstance(
                priors[key], (DeltaFunction, Constraint, float, int))]
            logger.info(f"Parameters over which likelihood is maximized: {self.parameters_to_be_updated}")
            if parameter_bounds is None:
                logger.info("No parameter bounds were given. Using priors instead.")
                self.parameter_bounds = self.get_bounds_from_priors(priors)
            else:
                self.parameter_bounds = self.get_parameter_list_from_dictionary(parameter_bounds)
            self.fiducial_parameters = self.find_maximum_likelihood_parameters(
                self.parameter_bounds, maximization_kwargs=maximization_kwargs)
        self.parameters.update(self.fiducial_parameters)
        logger.info(f"Fiducial likelihood: {self.log_likelihood_ratio():.2f}")
        self.parameters = dict(fiducial=0)

    def __repr__(self):
        # Previously the format string contained "\f" (form feed) instead of
        # "\t" and lacked the closing parenthesis.
        return (
            f"{self.__class__.__name__}(interferometers={self.interferometers},\n"
            f"\twaveform_generator={self.waveform_generator},\n"
            f"\tfiducial_parameters={self.fiducial_parameters})"
        )

    def setup_bins(self):
        """
        Setup the frequency bins following the method in
        https://arxiv.org/abs/1806.08792.

        If :code:`epsilon` is too small, the naive bins can be smaller than
        the frequency spacing of the data. We require that bins are at least
        as wide as this spacing, so repeated edges are dropped. At least one
        bin is always created.

        Sets :code:`bin_inds`, :code:`bin_sizes`, :code:`bin_freqs`,
        :code:`number_of_bins`, :code:`bin_widths`, :code:`bin_centers` and
        :code:`per_detector_fiducial_waveform_points`, and passes the bin
        edges to the waveform generator as :code:`frequency_bin_edges`.

        Raises
        ------
        ValueError
            If fewer than two frequencies lie in the usable frequency range.
        """
        frequency_array = self.waveform_generator.frequency_array
        gamma = self.gamma[:, np.newaxis]
        # Start the running max at the lowest frequency and the running min at
        # the highest so the loop below reduces over all interferometers.
        maximum_frequency = frequency_array[0]
        minimum_frequency = frequency_array[-1]
        for interferometer in self.interferometers:
            maximum_frequency = max(maximum_frequency, interferometer.maximum_frequency)
            minimum_frequency = min(minimum_frequency, interferometer.minimum_frequency)
        # Do not bin beyond where the fiducial waveform is non-zero.
        maximum_frequency = min(maximum_frequency, self.maximum_frequency)
        frequency_array_useful = frequency_array[
            (frequency_array >= minimum_frequency)
            & (frequency_array <= maximum_frequency)
        ]
        if len(frequency_array_useful) < 2:
            raise ValueError(
                f"Fewer than two frequencies between {minimum_frequency} Hz and "
                f"{maximum_frequency} Hz; cannot set up relative binning."
            )

        d_alpha = self.chi * 2 * np.pi / np.abs(
            (minimum_frequency ** gamma) * np.heaviside(-gamma, 1)
            - (maximum_frequency ** gamma) * np.heaviside(gamma, 1)
        )
        d_phi = np.sum(
            np.sign(gamma) * d_alpha * frequency_array_useful ** gamma,
            axis=0
        )
        d_phi_from_start = d_phi - d_phi[0]
        # Clamp to one bin: if the total phase change is below epsilon the
        # naive count is 0 and ``i / number_of_bins`` would divide by zero.
        number_of_bins = max(int(d_phi_from_start[-1] // self.epsilon), 1)
        bin_inds = list()
        bin_freqs = list()

        last_index = -1
        for i in range(number_of_bins + 1):
            bin_index = np.where(d_phi_from_start >= ((i / number_of_bins) * d_phi_from_start[-1]))[0][0]
            if bin_index == last_index:
                # Edge coincides with the previous one: bin would be narrower
                # than the frequency spacing, so skip it.
                continue
            bin_freq = frequency_array_useful[bin_index]
            last_index = bin_index
            # Convert to an index into the full (unmasked) frequency array.
            bin_index = np.where(frequency_array >= bin_freq)[0][0]
            bin_inds.append(bin_index)
            bin_freqs.append(bin_freq)
        self.bin_inds = np.array(bin_inds, dtype=int)
        self.bin_sizes = np.diff(bin_inds)
        # The last bin includes its upper edge.
        self.bin_sizes[-1] += 1
        self.bin_freqs = np.array(bin_freqs)
        self.number_of_bins = len(self.bin_inds) - 1
        logger.debug(
            f"Set up {self.number_of_bins} bins "
            f"between {minimum_frequency} Hz and {maximum_frequency} Hz"
        )
        self.waveform_generator.waveform_arguments["frequency_bin_edges"] = self.bin_freqs
        self.bin_widths = self.bin_freqs[1:] - self.bin_freqs[:-1]
        self.bin_centers = (self.bin_freqs[1:] + self.bin_freqs[:-1]) / 2
        for interferometer in self.interferometers:
            name = interferometer.name
            self.per_detector_fiducial_waveform_points[name] = (
                self.per_detector_fiducial_waveforms[name][self.bin_inds]
            )

    def set_fiducial_waveforms(self, parameters):
        """Compute and store the fiducial polarizations and detector responses.

        Also sets :code:`maximum_frequency` to the highest frequency at which
        the plus polarization is non-zero, and zeroes each detector response
        above that frequency.

        Parameters
        ----------
        parameters: dict
            Parameters of the fiducial waveform. The dict is copied, not
            modified.

        Raises
        ------
        ValueError
            If the waveform generator returns None, or if the plus
            polarization is zero at every frequency.
        """
        parameters = parameters.copy()
        parameters["fiducial"] = 1
        parameters.update(self.get_sky_frame_parameters(parameters=parameters))
        self.fiducial_polarizations = self.waveform_generator.frequency_domain_strain(
            parameters)

        # Check for failure before indexing into the result; otherwise a None
        # return would surface as an opaque TypeError.
        if self.fiducial_polarizations is None:
            raise ValueError(f"Cannot compute fiducial waveforms for {parameters}")

        nonzero_indices = np.where(self.fiducial_polarizations["plus"] != 0j)[0]
        if len(nonzero_indices) == 0:
            raise ValueError(
                f"Fiducial waveform is identically zero for {parameters}")
        maximum_nonzero_index = nonzero_indices[-1]
        logger.debug(f"Maximum Nonzero Index is {maximum_nonzero_index}")
        maximum_nonzero_frequency = self.waveform_generator.frequency_array[maximum_nonzero_index]
        logger.debug(f"Maximum Nonzero Frequency is {maximum_nonzero_frequency}")
        self.maximum_frequency = maximum_nonzero_frequency

        for interferometer in self.interferometers:
            logger.debug(f"Maximum Frequency is {interferometer.maximum_frequency}")
            wf = interferometer.get_detector_response(self.fiducial_polarizations, parameters)
            wf[interferometer.frequency_array > self.maximum_frequency] = 0
            self.per_detector_fiducial_waveforms[interferometer.name] = wf

    def find_maximum_likelihood_parameters(self, parameter_bounds,
                                           iterations=5, maximization_kwargs=None):
        """Iteratively refine the fiducial parameters by maximizing the likelihood.

        Each iteration runs :code:`scipy.optimize.differential_evolution`
        starting from the current best parameters. It then rebuilds the
        fiducial waveforms, bins and summary data at the new optimum. The loop
        stops early once the log-likelihood ratio improves by less than 0.1.

        Parameters
        ----------
        parameter_bounds: list
            One [min, max] pair per entry of :code:`parameters_to_be_updated`.
        iterations: int, optional
            Maximum number of optimize/rebin iterations.
        maximization_kwargs: dict, optional
            Extra keyword arguments for :code:`differential_evolution`.

        Returns
        -------
        dict
            The updated fiducial parameters. Equal to the current fiducial
            parameters when :code:`iterations` is 0.
        """
        if maximization_kwargs is None:
            maximization_kwargs = dict()
        self.parameters.update(self.fiducial_parameters)
        self.parameters["fiducial"] = 0
        updated_parameters_list = self.get_parameter_list_from_dictionary(self.fiducial_parameters)
        # Defined before the loop so that iterations=0 returns the current
        # parameters instead of raising UnboundLocalError.
        updated_parameters = self.get_parameter_dictionary_from_list(updated_parameters_list)
        old_fiducial_ln_likelihood = self.log_likelihood_ratio()
        logger.info(f"Fiducial ln likelihood ratio: {old_fiducial_ln_likelihood:.2f}")
        for it in range(iterations):
            logger.info(f"Optimizing fiducial parameters. Iteration : {it + 1}")
            output = differential_evolution(
                self.lnlike_scipy_maximize,
                bounds=parameter_bounds,
                x0=updated_parameters_list,
                **maximization_kwargs,
            )
            updated_parameters_list = output['x']
            updated_parameters = self.get_parameter_dictionary_from_list(updated_parameters_list)
            self.parameters.update(updated_parameters)
            self.set_fiducial_waveforms(updated_parameters)
            self.setup_bins()
            self.compute_summary_data()
            new_fiducial_ln_likelihood = self.log_likelihood_ratio()
            logger.info(f"Fiducial ln likelihood ratio: {new_fiducial_ln_likelihood:.2f}")
            if new_fiducial_ln_likelihood - old_fiducial_ln_likelihood < 0.1:
                break
            old_fiducial_ln_likelihood = new_fiducial_ln_likelihood

        logger.info("Fiducial waveforms updated")
        logger.info("Summary Data updated")
        return updated_parameters

    def lnlike_scipy_maximize(self, parameter_list):
        """Objective for the scipy minimizer: minus the log-likelihood ratio.

        Parameters
        ----------
        parameter_list: array_like
            Values ordered as :code:`parameters_to_be_updated`.
        """
        self.parameters.update(self.get_parameter_dictionary_from_list(parameter_list))
        return -self.log_likelihood_ratio()

    def get_parameter_dictionary_from_list(self, parameter_list):
        """Build a full parameter dict from optimized values.

        Keys in :code:`parameters_to_be_updated` take their values from
        :code:`parameter_list`, in order. Every other key of
        :code:`fiducial_parameters` keeps its fiducial value.
        """
        parameter_dictionary = dict(zip(self.parameters_to_be_updated, parameter_list))
        excluded_parameter_keys = set(self.fiducial_parameters) - set(self.parameters_to_be_updated)
        for key in excluded_parameter_keys:
            parameter_dictionary[key] = self.fiducial_parameters[key]
        return parameter_dictionary

    def get_parameter_list_from_dictionary(self, parameter_dict):
        """Return the values of :code:`parameters_to_be_updated`, in order."""
        return [parameter_dict[k] for k in self.parameters_to_be_updated]

    def get_bounds_from_priors(self, priors):
        """Return [minimum, maximum] prior bounds for each parameter being updated."""
        return [
            [priors[key].minimum, priors[key].maximum]
            for key in self.parameters_to_be_updated
        ]

    def compute_summary_data(self):
        """Precompute per-bin summary data for each interferometer.

        For each bin, with h0 the fiducial response, d the data and fm the bin
        center, this computes:

        - a0 = <h0, d>
        - a1 = <h0, d (f - fm)>
        - b0 = <h0, h0>
        - b1 = <h0, h0 (f - fm)>

        The results are stored as
        :code:`summary_data[name] = (a0, a1, b0, b1)`.

        Raises
        ------
        ValueError
            If a bin edge is not one of the interferometer's masked
            frequencies.
        """
        summary_data = dict()

        for interferometer in self.interferometers:
            mask = interferometer.frequency_mask
            masked_frequency_array = interferometer.frequency_array[mask]
            masked_bin_inds = []
            for edge in self.bin_freqs:
                matches = np.where(masked_frequency_array == edge)[0]
                if len(matches) == 0:
                    raise ValueError(
                        f"Bin edge {edge} Hz is not in the frequency mask of "
                        f"{interferometer.name}")
                masked_bin_inds.append(matches[0])
            # For the last bin, make sure to include
            # the last point in the frequency array
            masked_bin_inds[-1] += 1

            masked_strain = interferometer.frequency_domain_strain[mask]
            masked_h0 = self.per_detector_fiducial_waveforms[interferometer.name][mask]
            masked_psd = interferometer.power_spectral_density_array[mask]
            duration = interferometer.duration
            a0, b0, a1, b1 = np.zeros((4, self.number_of_bins), dtype=complex)

            for i in range(self.number_of_bins):
                idxs = slice(masked_bin_inds[i], masked_bin_inds[i + 1])

                strain = masked_strain[idxs]
                h0 = masked_h0[idxs]
                psd = masked_psd[idxs]

                frequencies = masked_frequency_array[idxs]
                # Expand about the same bin centre as _compute_full_waveform.
                # Reading the upper edge at masked_bin_inds[i + 1] would, for
                # the last bin (whose index was incremented above), be shifted
                # by one frequency step or fall off the end of the array.
                central_frequency = self.bin_centers[i]
                delta_frequency = frequencies - central_frequency

                a0[i] = noise_weighted_inner_product(h0, strain, psd, duration)
                b0[i] = noise_weighted_inner_product(h0, h0, psd, duration)
                a1[i] = noise_weighted_inner_product(h0, strain * delta_frequency, psd, duration)
                b1[i] = noise_weighted_inner_product(h0, h0 * delta_frequency, psd, duration)

            summary_data[interferometer.name] = (a0, a1, b0, b1)

        self.summary_data = summary_data

    def compute_waveform_ratio_per_interferometer(self, waveform_polarizations, interferometer):
        """Linear coefficients of h / h0 in each bin for one interferometer.

        The detector response is evaluated only at the bin edges. It is
        divided by the fiducial response at the same points.

        Returns
        -------
        list
            [r0, r1]: r0 is the mean of the ratio at the two edges of each
            bin, and r1 is its slope across the bin.
        """
        name = interferometer.name
        strain = interferometer.get_detector_response(
            waveform_polarizations=waveform_polarizations,
            parameters=self.parameters,
            frequencies=self.bin_freqs,
        )
        reference_strain = self.per_detector_fiducial_waveform_points[name]
        waveform_ratio = strain / reference_strain

        r0 = (waveform_ratio[1:] + waveform_ratio[:-1]) / 2
        r1 = (waveform_ratio[1:] - waveform_ratio[:-1]) / self.bin_widths

        return [r0, r1]

    def _compute_full_waveform(self, signal_polarizations, interferometer):
        """Reconstruct the full-resolution response from the binned ratio.

        The result is h0(f) * (r0 + r1 (f - fm)) inside each bin, and zero
        outside the binned range.
        """
        fiducial_waveform = self.per_detector_fiducial_waveforms[interferometer.name]
        r0, r1 = self.compute_waveform_ratio_per_interferometer(
            waveform_polarizations=signal_polarizations,
            interferometer=interferometer,
        )

        idxs = slice(self.bin_inds[0], self.bin_inds[-1] + 1)
        duplicated_r0 = np.repeat(r0, self.bin_sizes)
        duplicated_r1 = np.repeat(r1, self.bin_sizes)
        duplicated_fm = np.repeat(self.bin_centers, self.bin_sizes)

        f = interferometer.frequency_array
        full_waveform_ratio = np.zeros(f.shape[0], dtype=complex)
        full_waveform_ratio[idxs] = duplicated_r0 + duplicated_r1 * (f[idxs] - duplicated_fm)
        return fiducial_waveform * full_waveform_ratio

    def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True):
        """Compute the inner products needed by the likelihood, using the summary data.

        Parameters
        ----------
        waveform_polarizations: dict
            Polarizations evaluated at the bin edges.
        interferometer: bilby.gw.detector.Interferometer
            Detector for which the inner products are computed.
        return_array: bool, optional
            If True and time marginalization is enabled, also compute the
            FFT-based array of <d, h> over time shifts from the reconstructed
            full waveform.

        Returns
        -------
        _CalculatedSNRs
            Holds d_inner_h, the real optimal SNR squared, the complex matched
            filter SNR and d_inner_h_array (None unless computed).
        """
        r0, r1 = self.compute_waveform_ratio_per_interferometer(
            waveform_polarizations=waveform_polarizations,
            interferometer=interferometer,
        )
        a0, a1, b0, b1 = self.summary_data[interferometer.name]
        d_inner_h = np.sum(a0 * np.conjugate(r0) + a1 * np.conjugate(r1))
        h_inner_h = np.sum(b0 * np.abs(r0) ** 2 + 2 * b1 * np.real(r0 * np.conjugate(r1)))
        optimal_snr_squared = h_inner_h
        complex_matched_filter_snr = d_inner_h / (optimal_snr_squared ** 0.5)

        if return_array and self.time_marginalization:
            full_waveform = self._compute_full_waveform(
                signal_polarizations=waveform_polarizations,
                interferometer=interferometer,
            )
            # The final frequency sample is excluded from the FFT.
            d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft(
                full_waveform[0:-1]
                * interferometer.frequency_domain_strain.conjugate()[0:-1]
                / interferometer.power_spectral_density_array[0:-1])

        else:
            d_inner_h_array = None

        return self._CalculatedSNRs(
            d_inner_h=d_inner_h,
            optimal_snr_squared=optimal_snr_squared.real,
            complex_matched_filter_snr=complex_matched_filter_snr,
            d_inner_h_array=d_inner_h_array
        )