Coverage for bilby/gw/likelihood/relative.py: 98%

209 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import numpy as np 

2from scipy.optimize import differential_evolution 

3 

4from .base import GravitationalWaveTransient 

5from ...core.utils import logger 

6from ...core.prior.base import Constraint 

7from ...core.prior import DeltaFunction 

8from ..utils import noise_weighted_inner_product 

9 

10 

class RelativeBinningGravitationalWaveTransient(GravitationalWaveTransient):
    """A gravitational-wave transient likelihood object which uses the relative
    binning procedure to calculate a fast likelihood. See Zackay et al.
    arXiv:1806.08792

    Parameters
    ----------
    interferometers: list, bilby.gw.detector.InterferometerList
        A list of `bilby.detector.Interferometer` instances - contains the
        detector data and power spectral densities
    waveform_generator: `bilby.waveform_generator.WaveformGenerator`
        An object which computes the frequency-domain strain of the signal,
        given some set of parameters
    fiducial_parameters: dict, optional
        A starting guess for initial parameters of the event for finding the
        maximum likelihood (fiducial) waveform. These should be specified in
        the same parameter basis as the one that sampling is carried out in.
        For example, if sampling in `mass_1` and `mass_2`, the fiducial
        parameters should also be provided in `mass_1` and `mass_2.`
        If not given, the fiducial parameters are drawn from `priors`.
    parameter_bounds: dict, optional
        Dictionary of bounds (lists) for the initial parameters when finding
        the initial maximum likelihood (fiducial) waveform.
        If not given, the minimum and maximum of each prior are used.
    maximization_kwargs: dict, optional
        Keyword arguments forwarded to
        `scipy.optimize.differential_evolution` when
        `update_fiducial_parameters` is True.
    update_fiducial_parameters: bool, optional
        If true, the fiducial parameters are optimized with
        `scipy.optimize.differential_evolution` before the likelihood is
        used. This requires `priors` to be given. Default is False.
    distance_marginalization: bool, optional
        If true, marginalize over distance in the likelihood.
        This uses a look up table calculated at run time.
        The distance prior is set to be a delta function at the minimum
        distance allowed in the prior being marginalised over.
    time_marginalization: bool, optional
        If true, marginalize over time in the likelihood.
        This uses a FFT to calculate the likelihood over a regularly spaced
        grid.
        In order to cover the whole space the prior is set to be uniform over
        the spacing of the array of times.
        If using time marginalisation and jitter_time is True a "jitter"
        parameter is added to the prior which modifies the position of the
        grid of times.
    phase_marginalization: bool, optional
        If true, marginalize over phase in the likelihood.
        This is done analytically using a Bessel function.
        The phase prior is set to be a delta function at phase=0.
    priors: dict, optional
        If given, used in the distance and phase marginalization.
        Required when `fiducial_parameters` is not given or when
        `update_fiducial_parameters` is True.
    distance_marginalization_lookup_table: (dict, str), optional
        If a dict, dictionary containing the lookup_table, distance_array,
        (distance) prior_array, and reference_distance used to construct
        the table.
        If a string the name of a file containing these quantities.
        The lookup table is stored after construction in either the
        provided string or a default location:
        '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz'
    jitter_time: bool, optional
        Whether to introduce a `time_jitter` parameter. This avoids either
        missing the likelihood peak, or introducing biases in the
        reconstructed time posterior due to an insufficient sampling frequency.
        Default is True.
    reference_frame: (str, bilby.gw.detector.InterferometerList, list), optional
        Definition of the reference frame for the sky location.
        - "sky": sample in RA/dec, this is the default
        - e.g., "H1L1", ["H1", "L1"], InterferometerList(["H1", "L1"]):
          sample in azimuth and zenith, `azimuth` and `zenith` defined in the
          frame where the z-axis is aligned with the vector connecting H1
          and L1.
    time_reference: str, optional
        Name of the reference for the sampled time parameter.
        - "geocent"/"geocenter": sample in the time at the Earth's center,
          this is the default
        - e.g., "H1": sample in the time of arrival at H1
    chi: float, optional
        Tunable parameter which limits the perturbation of alpha when setting
        up the bin range. See https://arxiv.org/abs/1806.08792.
    epsilon: float, optional
        Tunable parameter which limits the differential phase change in each
        bin when setting up the bin range. See https://arxiv.org/abs/1806.08792.

    Returns
    -------
    Likelihood: `bilby.core.likelihood.Likelihood`
        A likelihood object, able to compute the likelihood of the data given
        some model parameters.

    Raises
    ------
    ValueError
        If `priors` is None while it is needed to draw the fiducial
        parameters or to optimize them.

    Notes
    -----
    The relative binning likelihood does not currently support calibration marginalization.
    """

    def __init__(self, interferometers,
                 waveform_generator,
                 fiducial_parameters=None,
                 parameter_bounds=None,
                 maximization_kwargs=None,
                 update_fiducial_parameters=False,
                 distance_marginalization=False,
                 time_marginalization=False,
                 phase_marginalization=False,
                 priors=None,
                 distance_marginalization_lookup_table=None,
                 jitter_time=True,
                 reference_frame="sky",
                 time_reference="geocenter",
                 chi=1,
                 epsilon=0.5):

        # Fail early with a clear message instead of an AttributeError /
        # TypeError deep inside the setup when the priors are needed.
        if priors is None and (fiducial_parameters is None or update_fiducial_parameters):
            raise ValueError(
                "priors must be provided when fiducial_parameters is None "
                "or when update_fiducial_parameters is True"
            )

        super().__init__(
            interferometers=interferometers,
            waveform_generator=waveform_generator,
            distance_marginalization=distance_marginalization,
            phase_marginalization=phase_marginalization,
            time_marginalization=time_marginalization,
            priors=priors,
            distance_marginalization_lookup_table=distance_marginalization_lookup_table,
            jitter_time=jitter_time,
            reference_frame=reference_frame,
            time_reference=time_reference)

        if fiducial_parameters is None:
            logger.info("Drawing fiducial parameters from prior.")
            fiducial_parameters = priors.sample()
        # Work on a copy so the caller's dictionary is never modified.
        self.fiducial_parameters = fiducial_parameters.copy()
        self.fiducial_parameters["fiducial"] = 0
        # Marginalized parameters are pinned to their reference values.
        if self.time_marginalization:
            # Read from self.interferometers rather than the raw argument,
            # which may be a plain list without a start_time attribute
            # (assumes the base class wraps it in an InterferometerList).
            self.fiducial_parameters["geocent_time"] = self.interferometers.start_time
        if self.distance_marginalization:
            self.fiducial_parameters["luminosity_distance"] = self._ref_dist
        if self.phase_marginalization:
            self.fiducial_parameters["phase"] = 0.0
        self.chi = chi
        self.epsilon = epsilon
        # Power-law exponents of the frequency terms used to bound the
        # phase perturbation when choosing the bins (see setup_bins).
        self.gamma = np.array([-5 / 3, -2 / 3, 1, 5 / 3, 7 / 3])
        self.maximum_frequency = waveform_generator.frequency_array[-1]
        self.fiducial_waveform_obtained = False
        self.check_if_bins_are_setup = False
        self.fiducial_polarizations = None
        self.per_detector_fiducial_waveforms = dict()
        self.per_detector_fiducial_waveform_points = dict()
        self.set_fiducial_waveforms(self.fiducial_parameters)
        logger.info("Initial fiducial waveforms set up")
        self.setup_bins()
        self.compute_summary_data()
        logger.info("Summary Data Obtained")

        if update_fiducial_parameters:
            logger.info("Using scipy optimization to find maximum likelihood parameters.")
            # Only parameters with a non-trivial prior are optimized.
            self.parameters_to_be_updated = [key for key in priors if not isinstance(
                priors[key], (DeltaFunction, Constraint, float, int))]
            logger.info(f"Parameters over which likelihood is maximized: {self.parameters_to_be_updated}")
            if parameter_bounds is None:
                logger.info("No parameter bounds were given. Using priors instead.")
                self.parameter_bounds = self.get_bounds_from_priors(priors)
            else:
                self.parameter_bounds = self.get_parameter_list_from_dictionary(parameter_bounds)
            self.fiducial_parameters = self.find_maximum_likelihood_parameters(
                self.parameter_bounds, maximization_kwargs=maximization_kwargs)
        self.parameters.update(self.fiducial_parameters)
        logger.info(f"Fiducial likelihood: {self.log_likelihood_ratio():.2f}")
        # Reset the parameters so that no fiducial values leak into sampling.
        self.parameters = dict(fiducial=0)

    def __repr__(self):
        """Return a string showing the interferometers, waveform generator
        and fiducial parameters of this likelihood."""
        return "{}(interferometers={},\n\twaveform_generator={},\n\tfiducial_parameters={})".format(
            self.__class__.__name__,
            self.interferometers,
            self.waveform_generator,
            self.fiducial_parameters,
        )

    def setup_bins(self):
        """
        Setup the frequency bins following the method in
        https://arxiv.org/abs/1806.08792.

        If :code:`epsilon` is too small, the naive bins can be smaller than
        the frequency spacing of the data. We require that bins are at least
        as wide as this spacing.

        Sets :code:`bin_inds`, :code:`bin_sizes`, :code:`bin_freqs`,
        :code:`bin_widths`, :code:`bin_centers` and :code:`number_of_bins`,
        stores the bin edges in the waveform generator's
        :code:`waveform_arguments["frequency_bin_edges"]` and evaluates the
        per-detector fiducial waveforms at the bin edges.
        """
        frequency_array = self.waveform_generator.frequency_array
        gamma = self.gamma[:, np.newaxis]
        # Seed the running max/min with the opposite extremes of the array so
        # that any interferometer value replaces them.
        maximum_frequency = frequency_array[0]
        minimum_frequency = frequency_array[-1]
        for interferometer in self.interferometers:
            maximum_frequency = max(maximum_frequency, interferometer.maximum_frequency)
            minimum_frequency = min(minimum_frequency, interferometer.minimum_frequency)
        # Never bin beyond the last non-zero sample of the fiducial waveform.
        maximum_frequency = min(maximum_frequency, self.maximum_frequency)
        frequency_array_useful = frequency_array[
            (frequency_array >= minimum_frequency)
            & (frequency_array <= maximum_frequency)
        ]

        d_alpha = self.chi * 2 * np.pi / np.abs(
            (minimum_frequency ** gamma) * np.heaviside(-gamma, 1)
            - (maximum_frequency ** gamma) * np.heaviside(gamma, 1)
        )
        d_phi = np.sum(
            np.sign(gamma) * d_alpha * frequency_array_useful ** gamma,
            axis=0
        )
        d_phi_from_start = d_phi - d_phi[0]
        # At least one bin is required; without the clamp a total phase
        # change below epsilon would give zero bins and a division by zero.
        number_of_bins = max(int(d_phi_from_start[-1] // self.epsilon), 1)
        bin_inds = list()
        bin_freqs = list()

        last_index = -1
        for i in range(number_of_bins + 1):
            bin_index = np.where(d_phi_from_start >= ((i / number_of_bins) * d_phi_from_start[-1]))[0][0]
            # Skip edges that fall on the same frequency sample as the
            # previous one, i.e. bins narrower than the data spacing.
            if bin_index == last_index:
                continue
            bin_freq = frequency_array_useful[bin_index]
            last_index = bin_index
            # Convert the index in the truncated array into an index in the
            # full frequency array.
            bin_index = np.where(frequency_array >= bin_freq)[0][0]
            bin_inds.append(bin_index)
            bin_freqs.append(bin_freq)
        self.bin_inds = np.array(bin_inds, dtype=int)
        self.bin_sizes = np.diff(bin_inds)
        # The last bin also contains its upper edge.
        self.bin_sizes[-1] += 1
        self.bin_freqs = np.array(bin_freqs)
        self.number_of_bins = len(self.bin_inds) - 1
        logger.debug(
            f"Set up {self.number_of_bins} bins "
            f"between {minimum_frequency} Hz and {maximum_frequency} Hz"
        )
        self.waveform_generator.waveform_arguments["frequency_bin_edges"] = self.bin_freqs
        self.bin_widths = self.bin_freqs[1:] - self.bin_freqs[:-1]
        self.bin_centers = (self.bin_freqs[1:] + self.bin_freqs[:-1]) / 2
        for interferometer in self.interferometers:
            name = interferometer.name
            self.per_detector_fiducial_waveform_points[name] = (
                self.per_detector_fiducial_waveforms[name][self.bin_inds]
            )

    def set_fiducial_waveforms(self, parameters):
        """Compute and store the fiducial waveform in every interferometer.

        Also updates :code:`maximum_frequency` to the frequency of the last
        non-zero sample of the fiducial plus polarization.

        Parameters
        ----------
        parameters: dict
            Parameters of the fiducial waveform. The dictionary is copied,
            the caller's object is not modified.

        Raises
        ------
        ValueError
            If the waveform generator returns None for these parameters.
        """
        parameters = parameters.copy()
        # Flag the waveform call as a fiducial (full frequency grid) evaluation.
        parameters["fiducial"] = 1
        parameters.update(self.get_sky_frame_parameters(parameters=parameters))
        self.fiducial_polarizations = self.waveform_generator.frequency_domain_strain(
            parameters)

        # This check has to come before the polarizations are indexed,
        # otherwise a failed waveform surfaces as an opaque TypeError.
        if self.fiducial_polarizations is None:
            raise ValueError(f"Cannot compute fiducial waveforms for {parameters}")

        maximum_nonzero_index = np.where(self.fiducial_polarizations["plus"] != 0j)[0][-1]
        logger.debug(f"Maximum Nonzero Index is {maximum_nonzero_index}")
        maximum_nonzero_frequency = self.waveform_generator.frequency_array[maximum_nonzero_index]
        logger.debug(f"Maximum Nonzero Frequency is {maximum_nonzero_frequency}")
        self.maximum_frequency = maximum_nonzero_frequency

        for interferometer in self.interferometers:
            logger.debug(f"Maximum Frequency is {interferometer.maximum_frequency}")
            wf = interferometer.get_detector_response(self.fiducial_polarizations, parameters)
            wf[interferometer.frequency_array > self.maximum_frequency] = 0
            self.per_detector_fiducial_waveforms[interferometer.name] = wf

    def find_maximum_likelihood_parameters(self, parameter_bounds,
                                           iterations=5, maximization_kwargs=None):
        """Iteratively maximize the likelihood to improve the fiducial parameters.

        After each optimization the fiducial waveforms, the bins and the
        summary data are recomputed at the new parameters. The iteration
        stops once the log likelihood ratio improves by less than 0.1.

        Parameters
        ----------
        parameter_bounds: list
            Bounds for each entry of :code:`parameters_to_be_updated`, in
            the same order.
        iterations: int, optional
            Maximum number of optimization rounds, default is 5.
        maximization_kwargs: dict, optional
            Keyword arguments forwarded to
            :code:`scipy.optimize.differential_evolution`.

        Returns
        -------
        dict: The parameters found in the last optimization round (the
            current fiducial parameters if :code:`iterations` is 0).
        """
        if maximization_kwargs is None:
            maximization_kwargs = dict()
        self.parameters.update(self.fiducial_parameters)
        self.parameters["fiducial"] = 0
        updated_parameters_list = self.get_parameter_list_from_dictionary(self.fiducial_parameters)
        # Defined up front so that the return value exists even when no
        # iteration is carried out.
        updated_parameters = self.get_parameter_dictionary_from_list(updated_parameters_list)
        old_fiducial_ln_likelihood = self.log_likelihood_ratio()
        logger.info(f"Fiducial ln likelihood ratio: {old_fiducial_ln_likelihood:.2f}")
        for it in range(iterations):
            logger.info(f"Optimizing fiducial parameters. Iteration : {it + 1}")
            output = differential_evolution(
                self.lnlike_scipy_maximize,
                bounds=parameter_bounds,
                x0=updated_parameters_list,
                **maximization_kwargs,
            )
            updated_parameters_list = output['x']
            updated_parameters = self.get_parameter_dictionary_from_list(updated_parameters_list)
            self.parameters.update(updated_parameters)
            self.set_fiducial_waveforms(updated_parameters)
            self.setup_bins()
            self.compute_summary_data()
            new_fiducial_ln_likelihood = self.log_likelihood_ratio()
            logger.info(f"Fiducial ln likelihood ratio: {new_fiducial_ln_likelihood:.2f}")
            if new_fiducial_ln_likelihood - old_fiducial_ln_likelihood < 0.1:
                break
            old_fiducial_ln_likelihood = new_fiducial_ln_likelihood

        logger.info("Fiducial waveforms updated")
        logger.info("Summary Data updated")
        return updated_parameters

    def lnlike_scipy_maximize(self, parameter_list):
        """Objective for the optimizer: the negative log likelihood ratio
        evaluated at the values in :code:`parameter_list`."""
        self.parameters.update(self.get_parameter_dictionary_from_list(parameter_list))
        return -self.log_likelihood_ratio()

    def get_parameter_dictionary_from_list(self, parameter_list):
        """Map a list of values, ordered like :code:`parameters_to_be_updated`,
        to a dictionary; all other fiducial parameters are filled in from
        :code:`fiducial_parameters`."""
        parameter_dictionary = dict(zip(self.parameters_to_be_updated, parameter_list))
        excluded_parameter_keys = set(self.fiducial_parameters) - set(self.parameters_to_be_updated)
        for key in excluded_parameter_keys:
            parameter_dictionary[key] = self.fiducial_parameters[key]
        return parameter_dictionary

    def get_parameter_list_from_dictionary(self, parameter_dict):
        """Return the values of :code:`parameter_dict` ordered like
        :code:`parameters_to_be_updated`."""
        return [parameter_dict[k] for k in self.parameters_to_be_updated]

    def get_bounds_from_priors(self, priors):
        """Return :code:`[minimum, maximum]` of the prior of every entry of
        :code:`parameters_to_be_updated`, in that order."""
        return [
            [priors[key].minimum, priors[key].maximum]
            for key in self.parameters_to_be_updated
        ]

    def compute_summary_data(self):
        """Compute the per-bin summary data for every interferometer.

        For each interferometer the tuple :code:`(a0, a1, b0, b1)` is stored
        in :code:`summary_data` under the interferometer name. Each entry is
        a complex array with one value per bin: the noise-weighted inner
        products of the fiducial waveform with the data (a) and with itself
        (b), without (0) and with (1) a weight given by the frequency offset
        from the bin center.
        """
        summary_data = dict()

        for interferometer in self.interferometers:
            mask = interferometer.frequency_mask
            masked_frequency_array = interferometer.frequency_array[mask]
            # Position of every bin edge in the masked frequency array.
            masked_bin_inds = [
                np.where(masked_frequency_array == edge)[0][0]
                for edge in self.bin_freqs
            ]
            # For the last bin, make sure to include
            # the last point in the frequency array
            masked_bin_inds[-1] += 1

            masked_strain = interferometer.frequency_domain_strain[mask]
            masked_h0 = self.per_detector_fiducial_waveforms[interferometer.name][mask]
            masked_psd = interferometer.power_spectral_density_array[mask]
            duration = interferometer.duration
            a0, b0, a1, b1 = np.zeros((4, self.number_of_bins), dtype=complex)

            for i in range(self.number_of_bins):
                idxs = slice(masked_bin_inds[i], masked_bin_inds[i + 1])

                strain = masked_strain[idxs]
                h0 = masked_h0[idxs]
                psd = masked_psd[idxs]

                # Use the center between the bin edges. Looking up the upper
                # edge through the shifted last index would read one sample
                # past the edge (or past the end of the array) for the last
                # bin and disagree with the centers used when the waveform
                # ratio is expanded.
                delta_frequency = masked_frequency_array[idxs] - self.bin_centers[i]

                a0[i] = noise_weighted_inner_product(h0, strain, psd, duration)
                b0[i] = noise_weighted_inner_product(h0, h0, psd, duration)
                a1[i] = noise_weighted_inner_product(h0, strain * delta_frequency, psd, duration)
                b1[i] = noise_weighted_inner_product(h0, h0 * delta_frequency, psd, duration)

            summary_data[interferometer.name] = (a0, a1, b0, b1)

        self.summary_data = summary_data

    def compute_waveform_ratio_per_interferometer(self, waveform_polarizations, interferometer):
        """Compute the piecewise-linear ratio of a waveform to the fiducial one.

        Parameters
        ----------
        waveform_polarizations: dict
            Polarizations of the waveform evaluated at the bin edges.
        interferometer: bilby.gw.detector.Interferometer
            Interferometer to project the waveform onto.

        Returns
        -------
        list: :code:`[r0, r1]`, the value of the ratio at the center of each
            bin and its slope across the bin.
        """
        name = interferometer.name
        strain = interferometer.get_detector_response(
            waveform_polarizations=waveform_polarizations,
            parameters=self.parameters,
            frequencies=self.bin_freqs,
        )
        reference_strain = self.per_detector_fiducial_waveform_points[name]
        waveform_ratio = strain / reference_strain

        r0 = (waveform_ratio[1:] + waveform_ratio[:-1]) / 2
        r1 = (waveform_ratio[1:] - waveform_ratio[:-1]) / self.bin_widths

        return [r0, r1]

    def _compute_full_waveform(self, signal_polarizations, interferometer):
        """Reconstruct the waveform on the full frequency grid of
        :code:`interferometer` by multiplying the fiducial waveform with the
        linearly interpolated waveform ratio. Samples outside the binned
        range are zero."""
        fiducial_waveform = self.per_detector_fiducial_waveforms[interferometer.name]
        r0, r1 = self.compute_waveform_ratio_per_interferometer(
            waveform_polarizations=signal_polarizations,
            interferometer=interferometer,
        )

        idxs = slice(self.bin_inds[0], self.bin_inds[-1] + 1)
        # Expand the per-bin coefficients to one value per frequency sample.
        duplicated_r0 = np.repeat(r0, self.bin_sizes)
        duplicated_r1 = np.repeat(r1, self.bin_sizes)
        duplicated_fm = np.repeat(self.bin_centers, self.bin_sizes)

        f = interferometer.frequency_array
        full_waveform_ratio = np.zeros(f.shape[0], dtype=complex)
        full_waveform_ratio[idxs] = duplicated_r0 + duplicated_r1 * (f[idxs] - duplicated_fm)
        return fiducial_waveform * full_waveform_ratio

    def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True):
        """Compute the SNR quantities for one interferometer from the summary data.

        Parameters
        ----------
        waveform_polarizations: dict
            Polarizations of the waveform evaluated at the bin edges.
        interferometer: bilby.gw.detector.Interferometer
            Interferometer to compute the SNRs for.
        return_array: bool, optional
            If true and time marginalization is used, also compute the
            array of inner products between data and signal from an FFT of
            the reconstructed full waveform.

        Returns
        -------
        The :code:`_CalculatedSNRs` object holding :code:`d_inner_h`,
        :code:`optimal_snr_squared`, :code:`complex_matched_filter_snr` and
        :code:`d_inner_h_array` (None unless requested as described above).
        """
        r0, r1 = self.compute_waveform_ratio_per_interferometer(
            waveform_polarizations=waveform_polarizations,
            interferometer=interferometer,
        )
        a0, a1, b0, b1 = self.summary_data[interferometer.name]
        d_inner_h = np.sum(a0 * np.conjugate(r0) + a1 * np.conjugate(r1))
        h_inner_h = np.sum(b0 * np.abs(r0) ** 2 + 2 * b1 * np.real(r0 * np.conjugate(r1)))
        optimal_snr_squared = h_inner_h
        complex_matched_filter_snr = d_inner_h / (optimal_snr_squared ** 0.5)

        if return_array and self.time_marginalization:
            full_waveform = self._compute_full_waveform(
                signal_polarizations=waveform_polarizations,
                interferometer=interferometer,
            )
            d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft(
                full_waveform[0:-1]
                * interferometer.frequency_domain_strain.conjugate()[0:-1]
                / interferometer.power_spectral_density_array[0:-1])

        else:
            d_inner_h_array = None

        return self._CalculatedSNRs(
            d_inner_h=d_inner_h,
            optimal_snr_squared=optimal_snr_squared.real,
            complex_matched_filter_snr=complex_matched_filter_snr,
            d_inner_h_array=d_inner_h_array
        )