Coverage for bilby/gw/likelihood/base.py: 86%

534 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1 

2import os 

3import copy 

4 

5import attr 

6import numpy as np 

7from scipy.special import logsumexp 

8 

9from ...core.likelihood import Likelihood 

10from ...core.utils import logger, BoundedRectBivariateSpline, create_time_series 

11from ...core.prior import Interped, Prior, Uniform, DeltaFunction 

12from ..detector import InterferometerList, get_empty_interferometer, calibration 

13from ..prior import BBHPriorDict, Cosmological 

14from ..utils import noise_weighted_inner_product, zenith_azimuth_to_ra_dec, ln_i0 

15 

16 

17class GravitationalWaveTransient(Likelihood): 

18 """ A gravitational-wave transient likelihood object 

19 

20 This is the usual likelihood object to use for transient gravitational 

21 wave parameter estimation. It computes the log-likelihood in the frequency 

22 domain assuming a colored Gaussian noise model described by a power 

23 spectral density. See Thrane & Talbot (2019), arxiv.org/abs/1809.02293. 

24 

25 Parameters 

26 ========== 

27 interferometers: list, bilby.gw.detector.InterferometerList 

28 A list of `bilby.detector.Interferometer` instances - contains the 

29 detector data and power spectral densities 

30 waveform_generator: `bilby.waveform_generator.WaveformGenerator` 

31 An object which computes the frequency-domain strain of the signal, 

32 given some set of parameters 

33 distance_marginalization: bool, optional 

34 If true, marginalize over distance in the likelihood. 

35 This uses a look up table calculated at run time. 

36 The distance prior is set to be a delta function at the minimum 

37 distance allowed in the prior being marginalised over. 

38 time_marginalization: bool, optional 

39 If true, marginalize over time in the likelihood. 

40 This uses a FFT to calculate the likelihood over a regularly spaced 

41 grid. 

42 In order to cover the whole space the prior is set to be uniform over 

43 the spacing of the array of times. 

44 If using time marginalisation and jitter_time is True a "jitter" 

45 parameter is added to the prior which modifies the position of the 

46 grid of times. 

47 phase_marginalization: bool, optional 

48 If true, marginalize over phase in the likelihood. 

49 This is done analytically using a Bessel function. 

50 The phase prior is set to be a delta function at phase=0. 

51 calibration_marginalization: bool, optional 

52 If true, marginalize over calibration response curves in the likelihood. 

53 This is done numerically over a number of calibration response curve realizations. 

54 priors: dict, optional 

55 If given, used in the distance and phase marginalization. 

56 Warning: when using marginalisation the dict is overwritten, which will change

57 the dict you are passing in. If this behaviour is undesired, pass `priors.copy()`. 

58 distance_marginalization_lookup_table: (dict, str), optional 

59 If a dict, dictionary containing the lookup_table, distance_array, 

60 (distance) prior_array, and reference_distance used to construct 

61 the table. 

62 If a string the name of a file containing these quantities. 

63 The lookup table is stored after construction in either the 

64 provided string or a default location: 

65 '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz' 

66 calibration_lookup_table: dict, optional 

67 If a dict, contains the arrays over which to marginalize for each interferometer or the filepaths of the 

68 calibration files. 

69 If not provided, but calibration_marginalization is used, then the appropriate file is created to 

70 contain the curves. 

71 number_of_response_curves: int, optional 

72 Number of curves from the calibration lookup table to use. 

73 Default is 1000. 

74 starting_index: int, optional 

75 Sets the index for the first realization of the calibration curve to be considered. 

76 This, coupled with number_of_response_curves, allows for restricting the set of curves used. This can be used 

77 when dealing with large frequency arrays to split the calculation into sections. 

78 Defaults to 0. 

79 jitter_time: bool, optional 

80 Whether to introduce a `time_jitter` parameter. This avoids either 

81 missing the likelihood peak, or introducing biases in the 

82 reconstructed time posterior due to an insufficient sampling frequency. 

83 Default is True; using this parameter is strongly encouraged.

84 reference_frame: (str, bilby.gw.detector.InterferometerList, list), optional 

85 Definition of the reference frame for the sky location. 

86 

87 - :code:`sky`: sample in RA/dec, this is the default 

88 - e.g., :code:`"H1L1", ["H1", "L1"], InterferometerList(["H1", "L1"])`: 

89 sample in azimuth and zenith, `azimuth` and `zenith` defined in the 

90 frame where the z-axis is aligned with the vector connecting H1

91 and L1. 

92 

93 time_reference: str, optional 

94 Name of the reference for the sampled time parameter. 

95 

96 - :code:`geocent`/:code:`geocenter`: sample in the time at the 

97 Earth's center, this is the default 

98 - e.g., :code:`H1`: sample in the time of arrival at H1 

99 

100 Returns 

101 ======= 

102 Likelihood: `bilby.core.likelihood.Likelihood` 

103 A likelihood object, able to compute the likelihood of the data given 

104 some model parameters 

105 

106 """ 

107 

@attr.s(slots=True, weakref_slot=False)
class _CalculatedSNRs:
    """Bundle of signal-to-noise quantities for one detector or a network sum.

    The scalar fields are coerced by the ``attrs`` converters at construction.
    The ``*_array`` fields default to ``None`` and are only filled when an
    array form of the inner products was requested. Instances support ``+``
    and ``+=`` so per-detector results can be accumulated.
    """

    d_inner_h = attr.ib(default=0j, converter=complex)
    optimal_snr_squared = attr.ib(default=0, converter=float)
    complex_matched_filter_snr = attr.ib(default=0j, converter=complex)
    d_inner_h_array = attr.ib(default=None)
    optimal_snr_squared_array = attr.ib(default=None)

    def __add__(self, other_snr):
        """Return a new object with the field-wise sum; ``self`` is not modified."""
        return copy.deepcopy(self).__iadd__(other_snr)

    def __iadd__(self, other_snr):
        """Accumulate ``other_snr`` into ``self`` field by field.

        A ``None`` field on ``self`` takes the value from ``other_snr``. A
        ``None`` field on ``other_snr`` leaves ``self`` unchanged. Otherwise
        the two values are added.
        """
        for field_name in self.__slots__:
            current = getattr(self, field_name)
            incoming = getattr(other_snr, field_name)
            if current is None:
                setattr(self, field_name, incoming)
            elif incoming is not None:
                setattr(self, field_name, current + incoming)
        return self

    @property
    def snrs_as_sample(self) -> dict:
        """Expose the SNRs as a sample-style dictionary.

        Returns
        =======
        dict
            ``matched_filter_snr`` (the complex matched-filter SNR) and
            ``optimal_snr`` (square root of the real part of
            ``optimal_snr_squared``).
        """
        optimal_snr = self.optimal_snr_squared.real ** 0.5
        return dict(
            matched_filter_snr=self.complex_matched_filter_snr,
            optimal_snr=optimal_snr,
        )

144 

def __init__(
        self, interferometers, waveform_generator, time_marginalization=False,
        distance_marginalization=False, phase_marginalization=False, calibration_marginalization=False, priors=None,
        distance_marginalization_lookup_table=None, calibration_lookup_table=None,
        number_of_response_curves=1000, starting_index=0, jitter_time=True, reference_frame="sky",
        time_reference="geocenter"
):
    """
    Store the data and waveform model, configure the time reference, and set
    up every requested marginalization.

    The arguments are described in the class docstring.

    Notes
    =====
    ``self.priors`` holds a copy of ``priors`` (made by the ``priors``
    setter), but the ``priors`` object passed in is itself modified:

    - for each enabled time/phase/distance marginalization, the marginalized
      parameter is replaced by a fixed float;
    - a ``time_jitter`` prior may be added;
    - ``redshift``/``comoving_distance`` entries are deleted when
      marginalizing over distance;
    - it is passed on to ``_setup_calibration_marginalization``.
    """

    self.waveform_generator = waveform_generator
    super(GravitationalWaveTransient, self).__init__(dict())
    self.interferometers = InterferometerList(interferometers)
    # The marginalization flags must be set before ``self.priors``: the
    # priors setter reads them to decide whether ``priors=None`` is allowed.
    self.time_marginalization = time_marginalization
    self.distance_marginalization = distance_marginalization
    self.phase_marginalization = phase_marginalization
    self.calibration_marginalization = calibration_marginalization
    self.priors = priors
    # Copy duration / sampling_frequency / start_time from the data onto
    # the waveform generator.
    self._check_set_duration_and_sampling_frequency_of_waveform_generator()
    # Lazily computed cache for noise_log_likelihood().
    self._noise_log_likelihood_value = None
    self.jitter_time = jitter_time
    self.reference_frame = reference_frame
    if "geocent" not in time_reference:
        # Time is sampled at a detector: time marginalization (and
        # therefore jittering) is only supported for geocenter time.
        self.time_reference = time_reference
        self.reference_ifo = get_empty_interferometer(self.time_reference)
        if self.time_marginalization:
            logger.info("Cannot marginalise over non-geocenter time.")
            self.time_marginalization = False
            self.jitter_time = False
    else:
        self.time_reference = "geocent"
        self.reference_ifo = None

    if self.time_marginalization:
        self._check_marginalized_prior_is_set(key='geocent_time')
        self._setup_time_marginalization()
        # Fix the sampled time in the caller's priors; it is marginalized
        # over numerically in the likelihood instead.
        priors['geocent_time'] = float(self.interferometers.start_time)
        if self.jitter_time:
            # Periodic offset of at most half of ``self._delta_tc``
            # (presumably the spacing of the marginalization time grid,
            # set by _setup_time_marginalization) in either direction.
            priors['time_jitter'] = Uniform(
                minimum=- self._delta_tc / 2,
                maximum=self._delta_tc / 2,
                boundary='periodic',
                name="time_jitter",
                latex_label="$t_j$"
            )
        self._marginalized_parameters.append('geocent_time')
    elif self.jitter_time:
        # Jitter only makes sense on top of time marginalization.
        logger.debug(
            "Time jittering requested with non-time-marginalised "
            "likelihood, ignoring.")
        self.jitter_time = False

    if self.phase_marginalization:
        self._check_marginalized_prior_is_set(key='phase')
        priors['phase'] = float(0)
        self._marginalized_parameters.append('phase')

    if self.distance_marginalization:
        self._lookup_table_filename = None
        self._check_marginalized_prior_is_set(key='luminosity_distance')
        # Tabulate the distance prior on a 10^4-point grid between its bounds.
        self._distance_array = np.linspace(
            self.priors['luminosity_distance'].minimum,
            self.priors['luminosity_distance'].maximum, int(1e4))
        self.distance_prior_array = np.array(
            [self.priors['luminosity_distance'].prob(distance)
             for distance in self._distance_array])
        # Reference distance: the median of the distance prior.
        self._ref_dist = self.priors['luminosity_distance'].rescale(0.5)
        self._setup_distance_marginalization(
            distance_marginalization_lookup_table)
        # Drop alternative distance parameters from the caller's priors so
        # only the fixed luminosity distance remains.
        for key in ['redshift', 'comoving_distance']:
            if key in priors:
                del priors[key]
        priors['luminosity_distance'] = float(self._ref_dist)
        self._marginalized_parameters.append('luminosity_distance')

    if self.calibration_marginalization:
        self.number_of_response_curves = number_of_response_curves
        self.starting_index = starting_index
        self._setup_calibration_marginalization(calibration_lookup_table, priors)
        self._marginalized_parameters.append('recalib_index')

223 

def __repr__(self):
    """Summarise the data, waveform model, marginalization flags and priors."""
    settings = (
        f"interferometers={self.interferometers},\n\t"
        f"waveform_generator={self.waveform_generator},\n\t"
        f"time_marginalization={self.time_marginalization}, "
        f"distance_marginalization={self.distance_marginalization}, "
        f"phase_marginalization={self.phase_marginalization}, "
        f"calibration_marginalization={self.calibration_marginalization}, "
        f"priors={self.priors}"
    )
    return f"{self.__class__.__name__}({settings})"

231 

def _check_set_duration_and_sampling_frequency_of_waveform_generator(self):
    """Synchronise the waveform generator's time/frequency settings with the data.

    For each of ``duration``, ``sampling_frequency`` and ``start_time``, the
    value from ``self.interferometers`` is copied onto
    ``self.waveform_generator``. A debug message is logged when the
    generator's value was unset (``None``) or differed from the
    interferometers' value. In both cases the generator's value is
    overwritten; no error is raised.
    """
    attributes = ['duration', 'sampling_frequency', 'start_time']
    for attribute in attributes:
        wfg_attr = getattr(self.waveform_generator, attribute)
        ifo_attr = getattr(self.interferometers, attribute)
        if wfg_attr is None:
            logger.debug(
                "The waveform_generator {} is None. Setting from the "
                "provided interferometers.".format(attribute))
        elif wfg_attr != ifo_attr:
            logger.debug(
                "The waveform_generator {} is not equal to that of the "
                "provided interferometers. Overwriting the "
                "waveform_generator.".format(attribute))
        # The interferometers are authoritative: always copy their value.
        setattr(self.waveform_generator, attribute, ifo_attr)

252 

def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True):
    """
    Compute the inner products and SNRs of a template in one interferometer.

    Parameters
    ----------
    waveform_polarizations: dict
        A dictionary of waveform polarizations and the corresponding array
    interferometer: bilby.gw.detector.Interferometer
        The bilby interferometer object
    return_array: bool
        If true, calculate and return internal array objects
        (d_inner_h_array and optimal_snr_squared_array), otherwise
        these are returned as None. Which arrays are filled depends on the
        marginalization settings:

        - time and calibration marginalization: ``d_inner_h_array`` holds
          one FFT of the calibrated integrand per calibration curve (one row
          per curve), and ``optimal_snr_squared_array`` holds one value per
          curve;
        - time marginalization only: ``d_inner_h_array`` holds the FFT of
          the integrand and ``optimal_snr_squared_array`` stays None;
        - calibration marginalization only, with no ``recalib_index``
          parameter set: one ``<d|h>`` and one ``<h|h>`` per calibration
          curve.

    Returns
    -------
    calculated_snrs: _CalculatedSNRs
        An object containing the SNR quantities and (if return_array=True)
        the internal array objects.

    """
    signal = self._compute_full_waveform(
        signal_polarizations=waveform_polarizations,
        interferometer=interferometer,
    )
    _mask = interferometer.frequency_mask

    # A specific calibration curve has already been selected (e.g. during
    # posterior reconstruction): apply it to the template in place.
    if 'recalib_index' in self.parameters:
        signal[_mask] *= self.calibration_draws[interferometer.name][int(self.parameters['recalib_index'])]

    d_inner_h = interferometer.inner_product(signal=signal)
    optimal_snr_squared = interferometer.optimal_snr_squared(signal=signal)
    complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5)

    d_inner_h_array = None
    optimal_snr_squared_array = None

    normalization = 4 / self.waveform_generator.duration

    # Only an explicit ``return_array=False`` skips the array computations.
    if return_array is not False:
        if self.time_marginalization and not self.calibration_marginalization:
            # FFT over frequency gives <d|h> for every coalescence-time shift.
            d_inner_h_array = normalization * np.fft.fft(
                signal[0:-1]
                * interferometer.frequency_domain_strain.conjugate()[0:-1]
                / interferometer.power_spectral_density_array[0:-1]
            )
        elif self.calibration_marginalization and (
            self.time_marginalization or 'recalib_index' not in self.parameters
        ):
            psd = interferometer.power_spectral_density_array
            calibration_draws = self.calibration_draws[interferometer.name]
            if self.time_marginalization:
                # One column per calibration curve; the calibration factors
                # only apply inside the frequency mask.
                d_inner_h_integrand = np.tile(
                    interferometer.frequency_domain_strain.conjugate() * signal / psd,
                    (self.number_of_response_curves, 1)).T
                d_inner_h_integrand[_mask] *= calibration_draws.T
                d_inner_h_array = normalization * np.fft.fft(
                    d_inner_h_integrand[0:-1], axis=0
                ).T
            else:
                d_inner_h_integrand = (
                    normalization
                    * interferometer.frequency_domain_strain.conjugate() * signal
                    / psd
                )
                d_inner_h_array = np.dot(d_inner_h_integrand[_mask], calibration_draws.T)

            # Shared by both calibration branches: <h|h> weighted by each
            # curve's entry in calibration_abs_draws.
            optimal_snr_squared_integrand = (
                normalization * np.abs(signal)**2 / psd
            )
            optimal_snr_squared_array = np.dot(
                optimal_snr_squared_integrand[_mask],
                self.calibration_abs_draws[interferometer.name].T
            )

    return self._CalculatedSNRs(
        d_inner_h=d_inner_h,
        optimal_snr_squared=optimal_snr_squared.real,
        complex_matched_filter_snr=complex_matched_filter_snr,
        d_inner_h_array=d_inner_h_array,
        optimal_snr_squared_array=optimal_snr_squared_array,
    )

346 

def _check_marginalized_prior_is_set(self, key):
    """
    Make sure ``self.priors`` holds a usable prior for a marginalized parameter.

    Parameters
    ==========
    key: str
        Name of the marginalized parameter, e.g. ``'geocent_time'``,
        ``'phase'`` or ``'luminosity_distance'``.

    Raises
    ======
    ValueError
        If the existing ``Prior`` for ``key`` is fixed.
    TypeError
        If ``key`` is ``'luminosity_distance'`` and a ``redshift`` or
        ``comoving_distance`` prior is present that is not ``Cosmological``.

    Notes
    =====
    A missing entry, or one that is not a ``Prior`` instance, is replaced:

    - ``geocent_time``: a uniform prior spanning the data segment;
    - ``luminosity_distance``: the prior implied by a cosmological
      ``redshift``/``comoving_distance`` prior when one is present (that
      entry is then removed), otherwise the BBH default;
    - any other key: the BBH default.
    """
    # Only Prior objects can be "fixed"; any other value (e.g. a bare float)
    # is treated as "not provided", as in the check below.
    has_prior = key in self.priors and isinstance(self.priors[key], Prior)
    if has_prior and self.priors[key].is_fixed:
        raise ValueError(
            "Cannot use marginalized likelihood for {}: prior is fixed".format(key)
        )
    if has_prior:
        return

    logger.warning(
        'Prior not provided for {}, using the BBH default.'.format(key))
    if key == 'geocent_time':
        self.priors[key] = Uniform(
            self.interferometers.start_time,
            self.interferometers.start_time + self.interferometers.duration)
    elif key == 'luminosity_distance':
        converted = False
        # Use a separate loop variable so ``key`` keeps naming the
        # marginalized parameter.
        for cosmological_key in ['redshift', 'comoving_distance']:
            if cosmological_key in self.priors:
                if not isinstance(self.priors[cosmological_key], Cosmological):
                    raise TypeError(
                        "To marginalize over {}, the prior must be specified as a "
                        "subclass of bilby.gw.prior.Cosmological.".format(cosmological_key)
                    )
                self.priors['luminosity_distance'] = self.priors[cosmological_key].get_corresponding_prior(
                    'luminosity_distance'
                )
                del self.priors[cosmological_key]
                converted = True
        if not converted:
            # Without a cosmological prior to convert, fall back to the BBH
            # default announced above; otherwise the later lookup of
            # self.priors['luminosity_distance'] would fail.
            self.priors[key] = BBHPriorDict()[key]
    else:
        self.priors[key] = BBHPriorDict()[key]

374 

@property
def priors(self):
    """The prior dictionary used for the marginalizations, or ``None``."""
    return self._prior

@priors.setter
def priors(self, priors):
    if priors is None:
        marginalizing = any((
            self.time_marginalization,
            self.phase_marginalization,
            self.distance_marginalization,
        ))
        if marginalizing:
            raise ValueError("You can't use a marginalized likelihood without specifying a priors")
        self._prior = None
        return
    # Keep a (shallow) copy so later key changes on the caller's dict do not
    # alter the stored priors.
    self._prior = priors.copy()

392 mask = interferometer.frequency_mask 

393 log_l -= noise_weighted_inner_product( 

394 interferometer.frequency_domain_strain[mask], 

395 interferometer.frequency_domain_strain[mask], 

396 interferometer.power_spectral_density_array[mask], 

397 self.waveform_generator.duration) / 2 

398 return float(np.real(log_l)) 

399 

def noise_log_likelihood(self):
    """Return the noise log-likelihood, computing it only on the first call."""
    cached = self._noise_log_likelihood_value
    if cached is None:
        cached = self._calculate_noise_log_likelihood()
        self._noise_log_likelihood_value = cached
    return cached

405 

def log_likelihood_ratio(self):
    """
    Log-likelihood ratio (signal-plus-noise vs. noise-only) for ``self.parameters``.

    Returns
    =======
    float
        The (possibly marginalized) log-likelihood ratio. If the waveform
        generator returns ``None``, ``np.nan_to_num(-np.inf)`` (the most
        negative finite float) is returned instead.

    Notes
    =====
    With time marginalization and ``jitter_time``, ``time_jitter`` is added
    to ``geocent_time`` while the SNRs are computed. The saved original value
    is restored afterwards, exactly, even if an exception is raised.
    """
    waveform_polarizations = \
        self.waveform_generator.frequency_domain_strain(self.parameters)
    if waveform_polarizations is None:
        return np.nan_to_num(-np.inf)

    apply_jitter = self.time_marginalization and self.jitter_time
    if apply_jitter:
        original_time = self.parameters['geocent_time']
        self.parameters['geocent_time'] = original_time + self.parameters['time_jitter']

    try:
        self.parameters.update(self.get_sky_frame_parameters())

        total_snrs = self._CalculatedSNRs()
        for interferometer in self.interferometers:
            total_snrs += self.calculate_snrs(
                waveform_polarizations=waveform_polarizations,
                interferometer=interferometer)

        log_l = self.compute_log_likelihood_from_snrs(total_snrs)
    finally:
        if apply_jitter:
            # Restore the saved value; subtracting the jitter again is not
            # exact in floating point and would slowly drift geocent_time.
            self.parameters['geocent_time'] = original_time

    return float(log_l.real)

432 

def compute_log_likelihood_from_snrs(self, total_snrs):
    """
    Turn accumulated SNR quantities into a log-likelihood ratio.

    Dispatch by precedence: calibration, then time, distance and phase
    marginalization. Without any marginalization, return
    ``Re<d|h> - <h|h>/2``.

    Parameters
    ==========
    total_snrs: _CalculatedSNRs
        SNR quantities for one detector or summed over the network.
    """
    if self.calibration_marginalization:
        return self.calibration_marginalized_likelihood(
            d_inner_h_calibration_array=total_snrs.d_inner_h_array,
            h_inner_h=total_snrs.optimal_snr_squared_array)

    if self.time_marginalization:
        return self.time_marginalized_likelihood(
            d_inner_h_tc_array=total_snrs.d_inner_h_array,
            h_inner_h=total_snrs.optimal_snr_squared)

    d_inner_h = total_snrs.d_inner_h
    h_inner_h = total_snrs.optimal_snr_squared

    if self.distance_marginalization:
        return self.distance_marginalized_likelihood(d_inner_h=d_inner_h, h_inner_h=h_inner_h)

    if self.phase_marginalization:
        return self.phase_marginalized_likelihood(d_inner_h=d_inner_h, h_inner_h=h_inner_h)

    return np.real(d_inner_h) - h_inner_h / 2

457 

def compute_per_detector_log_likelihood(self):
    """
    Evaluate the log-likelihood ratio separately for every interferometer.

    Each result is stored in ``self.parameters`` under the key
    ``'<interferometer name>_log_likelihood'``.

    Returns
    =======
    dict
        A copy of ``self.parameters``, including the per-detector entries.

    Notes
    =====
    With time marginalization and ``jitter_time``, ``time_jitter`` is added
    to ``geocent_time`` during the computation. The saved original value is
    restored afterwards, exactly, even if an exception is raised.
    """
    waveform_polarizations = \
        self.waveform_generator.frequency_domain_strain(self.parameters)

    apply_jitter = self.time_marginalization and self.jitter_time
    if apply_jitter:
        original_time = self.parameters['geocent_time']
        self.parameters['geocent_time'] = original_time + self.parameters['time_jitter']

    try:
        self.parameters.update(self.get_sky_frame_parameters())

        for interferometer in self.interferometers:
            per_detector_snr = self.calculate_snrs(
                waveform_polarizations=waveform_polarizations,
                interferometer=interferometer)

            self.parameters['{}_log_likelihood'.format(interferometer.name)] = \
                self.compute_log_likelihood_from_snrs(per_detector_snr)
    finally:
        if apply_jitter:
            # Restore the saved value; subtracting the jitter again is not
            # exact in floating point.
            self.parameters['geocent_time'] = original_time

    return self.parameters.copy()

479 

def generate_posterior_sample_from_marginalized_likelihood(self):
    """
    Reconstruct the marginalized parameters from a run whose likelihood
    explicitly marginalised over calibration/time/distance/phase.

    See Eq. (C29-C32) of https://arxiv.org/abs/1809.02293

    Returns
    =======
    sample: dict
        A copy of the parameters, with new samples for every marginalized
        parameter (``recalib_index``, ``geocent_time``,
        ``luminosity_distance``, ``phase``). If nothing was marginalized, an
        unchanged copy is returned, so the result is never the internal
        ``self.parameters`` dict.

    Notes
    =====
    This involves a deepcopy of the signal to avoid issues with waveform
    caching, as the signal is overwritten in place.
    """
    if len(self._marginalized_parameters) == 0:
        return self.parameters.copy()

    signal_polarizations = copy.deepcopy(
        self.waveform_generator.frequency_domain_strain(self.parameters))

    # Order matters: each step reads self.parameters, which earlier steps
    # update, and the distance step rescales signal_polarizations in place.
    if self.calibration_marginalization:
        self.parameters['recalib_index'] = \
            self.generate_calibration_sample_from_marginalized_likelihood(
                signal_polarizations=signal_polarizations)
    if self.time_marginalization:
        self.parameters['geocent_time'] = \
            self.generate_time_sample_from_marginalized_likelihood(
                signal_polarizations=signal_polarizations)
    if self.distance_marginalization:
        self.parameters['luminosity_distance'] = \
            self.generate_distance_sample_from_marginalized_likelihood(
                signal_polarizations=signal_polarizations)
    if self.phase_marginalization:
        self.parameters['phase'] = \
            self.generate_phase_sample_from_marginalized_likelihood(
                signal_polarizations=signal_polarizations)
    return self.parameters.copy()

521 

def generate_calibration_sample_from_marginalized_likelihood(
        self, signal_polarizations=None):
    """
    Generate a single sample from the posterior distribution for the set of calibration response curves when
    explicitly marginalizing over the calibration uncertainty.

    Parameters
    ----------
    signal_polarizations: dict, optional
        Polarizations modes of the template.

    Returns
    -------
    new_calibration: int
        Index, in ``range(self.number_of_response_curves)``, of the
        calibration response curve drawn from the calibration posterior.

    Notes
    -----
    Any existing ``recalib_index`` entry is removed from ``self.parameters``
    first, so that the likelihood is evaluated for every curve rather than a
    single selected one.
    """
    from ...core.utils.random import rng

    if 'recalib_index' in self.parameters:
        self.parameters.pop('recalib_index')
    self.parameters.update(self.get_sky_frame_parameters())
    if signal_polarizations is None:
        signal_polarizations = \
            self.waveform_generator.frequency_domain_strain(self.parameters)

    log_like = self.get_calibration_log_likelihoods(signal_polarizations=signal_polarizations)

    # Subtract the maximum before exponentiating to avoid overflow, then
    # normalise into a discrete probability distribution over the curves.
    calibration_post = np.exp(log_like - max(log_like))
    calibration_post /= np.sum(calibration_post)

    new_calibration = rng.choice(self.number_of_response_curves, p=calibration_post)

    return new_calibration

555 

def generate_time_sample_from_marginalized_likelihood(
        self, signal_polarizations=None):
    """
    Generate a single sample from the posterior distribution for coalescence
    time when using a likelihood which explicitly marginalises over time.

    In order to resolve the posterior we artificially upsample to 16kHz.

    See Eq. (C29-C32) of https://arxiv.org/abs/1809.02293

    Parameters
    ==========
    signal_polarizations: dict, optional
        Polarizations modes of the template.

    Returns
    =======
    new_time: float
        Sample from the time posterior.

    Notes
    =====
    If ``jitter_time`` is set, ``time_jitter`` is added to
    ``self.parameters['geocent_time']`` while the posterior is computed. The
    original value is restored before returning (also on error), as in
    ``log_likelihood_ratio``.
    """
    self.parameters.update(self.get_sky_frame_parameters())
    jittered = self.jitter_time
    if jittered:
        original_time = self.parameters['geocent_time']
        self.parameters['geocent_time'] = original_time + self.parameters['time_jitter']

    try:
        if signal_polarizations is None:
            signal_polarizations = \
                self.waveform_generator.frequency_domain_strain(self.parameters)

        # Candidate times on a 16384 Hz grid covering one segment, starting at
        # the (jittered) current time and wrapped back into the data segment.
        times = create_time_series(
            sampling_frequency=16384,
            starting_time=self.parameters['geocent_time'] - self.waveform_generator.start_time,
            duration=self.waveform_generator.duration)
        times = times % self.waveform_generator.duration
        times += self.waveform_generator.start_time

        prior = self.priors["geocent_time"]
        in_prior = (times >= prior.minimum) & (times < prior.maximum)
        times = times[in_prior]

        # Zero-pad the frequency-domain products to the 16384 Hz length so
        # the FFT evaluates <d|h> on the fine time grid.
        n_time_steps = int(self.waveform_generator.duration * 16384)
        d_inner_h = np.zeros(len(times), dtype=complex)
        psd = np.ones(n_time_steps)
        signal_long = np.zeros(n_time_steps, dtype=complex)
        data = np.zeros(n_time_steps, dtype=complex)
        h_inner_h = np.zeros(1)
        for ifo in self.interferometers:
            ifo_length = len(ifo.frequency_domain_strain)
            mask = ifo.frequency_mask
            signal = self._compute_full_waveform(
                signal_polarizations=signal_polarizations,
                interferometer=ifo,
            )
            signal_long[:ifo_length] = signal
            data[:ifo_length] = np.conj(ifo.frequency_domain_strain)
            psd[:ifo_length][mask] = ifo.power_spectral_density_array[mask]
            d_inner_h += np.fft.fft(signal_long * data / psd)[in_prior]
            h_inner_h += ifo.optimal_snr_squared(signal=signal).real

        if self.distance_marginalization:
            time_log_like = self.distance_marginalized_likelihood(
                d_inner_h, h_inner_h)
        elif self.phase_marginalization:
            time_log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2
        else:
            time_log_like = (d_inner_h.real - h_inner_h.real / 2)

        time_prior_array = self.priors['geocent_time'].prob(times)
        time_post = np.exp(time_log_like - max(time_log_like)) * time_prior_array

        # Keep points within a factor 1000 of the peak; if fewer than three
        # survive, also keep their immediate neighbours.
        keep = (time_post > max(time_post) / 1000)
        if sum(keep) < 3:
            keep[1:-1] = keep[1:-1] | keep[2:] | keep[:-2]
        time_post = time_post[keep]
        times = times[keep]

        new_time = Interped(times, time_post).sample()
    finally:
        if jittered:
            self.parameters['geocent_time'] = original_time
    return new_time

632 

def generate_distance_sample_from_marginalized_likelihood(
        self, signal_polarizations=None):
    """
    Draw one luminosity distance from its posterior when the likelihood
    explicitly marginalises over distance.

    See Eq. (C29-C32) of https://arxiv.org/abs/1809.02293

    Parameters
    ==========
    signal_polarizations: dict, optional
        Polarizations modes of the template. After the distance is drawn,
        they are rescaled in place (via ``self._rescale_signal``) so that
        later reconstruction steps see the new distance.

    Returns
    =======
    new_distance: float
        Sample from the distance posterior.
    """
    self.parameters.update(self.get_sky_frame_parameters())
    if signal_polarizations is None:
        signal_polarizations = \
            self.waveform_generator.frequency_domain_strain(self.parameters)

    d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations)

    # The inner products were evaluated at the current distance; <d|h> scales
    # as 1/D and <h|h> as 1/D^2, so map them onto the distance grid.
    current_distance = self.parameters['luminosity_distance']
    grid = self._distance_array
    d_inner_h_grid = d_inner_h * current_distance / grid
    h_inner_h_grid = h_inner_h * current_distance**2 / grid**2

    if not self.phase_marginalization:
        log_like = (d_inner_h_grid.real - h_inner_h_grid.real / 2)
    else:
        log_like = ln_i0(abs(d_inner_h_grid)) - h_inner_h_grid.real / 2

    posterior = np.exp(log_like - max(log_like)) * self.distance_prior_array
    sampled_distance = Interped(grid, posterior).sample()

    self._rescale_signal(signal_polarizations, sampled_distance)
    return sampled_distance

682 

def _calculate_inner_products(self, signal_polarizations):
    """
    Network-summed ``<d|h>`` and ``<h|h>`` for the given polarizations.

    Only the scalar SNR quantities are used here, so ``calculate_snrs`` is
    called with ``return_array=False``. This skips the FFT and
    calibration-array work needed only by the marginalized likelihoods.

    Returns
    =======
    tuple
        ``(d_inner_h, h_inner_h)`` summed over ``self.interferometers``.
    """
    d_inner_h = 0
    h_inner_h = 0
    for interferometer in self.interferometers:
        per_detector_snr = self.calculate_snrs(
            signal_polarizations, interferometer, return_array=False)

        d_inner_h += per_detector_snr.d_inner_h
        h_inner_h += per_detector_snr.optimal_snr_squared
    return d_inner_h, h_inner_h

693 

def _compute_full_waveform(self, signal_polarizations, interferometer):
    """
    Project the waveform polarizations onto one interferometer's response.

    This is a separate hook so that likelihoods which do not work on the full
    frequency array (e.g. relative binning) can override it.

    Parameters
    ==========
    signal_polarizations: dict
        Dictionary containing the waveform evaluated at
        :code:`interferometer.frequency_array`.
    interferometer: bilby.gw.detector.Interferometer
        Interferometer to compute the response with respect to.
    """
    current_parameters = self.parameters
    response = interferometer.get_detector_response(signal_polarizations, current_parameters)
    return response

710 

def generate_phase_sample_from_marginalized_likelihood(
        self, signal_polarizations=None):
    r"""
    Draw one phase from its posterior when the likelihood explicitly
    marginalises over phase.

    See Eq. (C29-C32) of https://arxiv.org/abs/1809.02293

    Parameters
    ==========
    signal_polarizations: dict, optional
        Polarizations modes of the template.

    Returns
    =======
    new_phase: float
        Sample from the phase posterior.

    Notes
    =====
    This is only valid when the template depends on phase as
    mu(phi) \propto exp(-2i phi).
    """
    self.parameters.update(self.get_sky_frame_parameters())
    if signal_polarizations is None:
        signal_polarizations = \
            self.waveform_generator.frequency_domain_strain(self.parameters)
    d_inner_h, h_inner_h = self._calculate_inner_products(signal_polarizations)

    # Evaluate the posterior on 101 evenly spaced phases in [0, 2 pi].
    phase_grid = np.linspace(0, 2 * np.pi, 101)
    log_post = (d_inner_h * np.exp(-2j * phase_grid) - h_inner_h / 2).real
    weights = np.exp(log_post - max(log_post))
    return Interped(phase_grid, weights).sample()

745 

def distance_marginalized_likelihood(self, d_inner_h, h_inner_h):
    """
    Distance-marginalized log-likelihood ratio.

    The inner products are first mapped to the reference distance with
    ``self._setup_rho``. ``<d|h>`` is then reduced to its modulus (when also
    marginalizing over phase) or its real part, and the pair is evaluated
    pointwise (``grid=False``) with ``self._interp_dist_margd_loglikelihood``
    (presumably the interpolated lookup table built during distance-
    marginalization setup).
    """
    rho_mf_ref, rho_opt_ref = self._setup_rho(d_inner_h, h_inner_h)
    project = np.abs if self.phase_marginalization else np.real
    return self._interp_dist_margd_loglikelihood(
        project(rho_mf_ref), rho_opt_ref, grid=False)

756 

def phase_marginalized_likelihood(self, d_inner_h, h_inner_h):
    """
    Phase-marginalized log-likelihood ratio, ``ln I0(|<d|h>|) - <h|h>/2``.

    With both calibration and time marginalization, ``h_inner_h`` holds one
    value per calibration curve and is given a trailing axis so that it
    broadcasts against the (curve, time) array of ``<d|h>``.
    """
    log_bessel = ln_i0(abs(d_inner_h))
    if self.calibration_marginalization and self.time_marginalization:
        h_inner_h = h_inner_h[:, np.newaxis]
    return log_bessel - h_inner_h / 2

764 

def time_marginalized_likelihood(self, d_inner_h_tc_array, h_inner_h):
    """
    Marginalize the log-likelihood ratio over coalescence time.

    ``d_inner_h_tc_array`` holds ``<d|h>`` on the time grid ``self._times``,
    shifted by ``time_jitter`` when jittering. Grid points outside the
    ``geocent_time`` prior bounds are dropped. The rest are combined with
    ``logsumexp``, weighted by prior probability times ``self._delta_tc``.

    Parameters
    ==========
    d_inner_h_tc_array: array_like
        ``<d|h>`` per time-grid point. With calibration marginalization it is
        2-D, with the time axis last.
    h_inner_h: float or array_like
        ``<h|h>``; one value per calibration curve when marginalizing over
        calibration.

    Returns
    =======
    float or numpy.ndarray
        Time-marginalized log-likelihood ratio (one value per calibration
        curve when marginalizing over calibration).
    """
    if self.jitter_time:
        time_grid = self._times + self.parameters['time_jitter']
    else:
        time_grid = self._times

    time_prior = self.priors['geocent_time']
    in_bounds = (time_grid >= time_prior.minimum) & (time_grid <= time_prior.maximum)
    time_grid = time_grid[in_bounds]
    prior_weights = time_prior.prob(time_grid) * self._delta_tc

    if self.calibration_marginalization:
        d_inner_h_tc_array = d_inner_h_tc_array[:, in_bounds]
    else:
        d_inner_h_tc_array = d_inner_h_tc_array[in_bounds]

    if self.distance_marginalization:
        # NOTE(review): with calibration marginalization h_inner_h has one
        # entry per curve while d_inner_h_tc_array is (curves, times); unlike
        # the branches below no new axis is added here - confirm that the
        # distance interpolant broadcasts these shapes as intended.
        log_l_tc_array = self.distance_marginalized_likelihood(
            d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h)
    elif self.phase_marginalization:
        log_l_tc_array = self.phase_marginalized_likelihood(
            d_inner_h=d_inner_h_tc_array, h_inner_h=h_inner_h)
    else:
        if self.calibration_marginalization:
            h_inner_h = h_inner_h[:, np.newaxis]
        log_l_tc_array = np.real(d_inner_h_tc_array) - h_inner_h / 2
    return logsumexp(log_l_tc_array, b=prior_weights, axis=-1)

791 

def get_calibration_log_likelihoods(self, signal_polarizations=None):
    """Compute the log-likelihood ratio array across calibration draws.

    ``self.parameters`` is first updated in place with the sky-frame
    parameters (ra, dec, geocent_time) from ``get_sky_frame_parameters``. The
    SNR quantities of all interferometers are summed, and any enabled time,
    distance or phase marginalisation is applied (checked in that order).

    Parameters
    ==========
    signal_polarizations: dict, optional
        Pre-computed waveform polarisations; generated from
        ``self.parameters`` by the waveform generator if not given.

    Returns
    =======
    array_like: the log-likelihood ratio, presumably one entry per
    calibration response curve.
    """
    self.parameters.update(self.get_sky_frame_parameters())
    if signal_polarizations is None:
        signal_polarizations = self.waveform_generator.frequency_domain_strain(
            self.parameters)

    network_snrs = self._CalculatedSNRs()
    for ifo in self.interferometers:
        network_snrs += self.calculate_snrs(
            waveform_polarizations=signal_polarizations,
            interferometer=ifo)

    d_inner_h = network_snrs.d_inner_h_array
    h_inner_h = network_snrs.optimal_snr_squared_array
    if self.time_marginalization:
        return self.time_marginalized_likelihood(
            d_inner_h_tc_array=d_inner_h, h_inner_h=h_inner_h)
    if self.distance_marginalization:
        return self.distance_marginalized_likelihood(
            d_inner_h=d_inner_h, h_inner_h=h_inner_h)
    if self.phase_marginalization:
        return self.phase_marginalized_likelihood(
            d_inner_h=d_inner_h, h_inner_h=h_inner_h)
    return np.real(d_inner_h - h_inner_h / 2)

825 

def calibration_marginalized_likelihood(self, d_inner_h_calibration_array, h_inner_h):
    """Marginalise the log-likelihood ratio over calibration response curves.

    Any enabled time, distance or phase marginalisation is applied first
    (checked in that order). The result is then averaged, in linear
    likelihood space, over ``self.number_of_response_curves`` curves.

    Parameters
    ==========
    d_inner_h_calibration_array: array_like
        <d|h> evaluated for each calibration draw.
    h_inner_h: array_like
        <h|h> evaluated for each calibration draw.

    Returns
    =======
    float: log of the mean likelihood ratio over response curves.
    """
    d_cal = d_inner_h_calibration_array
    if self.time_marginalization:
        per_curve = self.time_marginalized_likelihood(
            d_inner_h_tc_array=d_cal, h_inner_h=h_inner_h)
    elif self.distance_marginalization:
        per_curve = self.distance_marginalized_likelihood(
            d_inner_h=d_cal, h_inner_h=h_inner_h)
    elif self.phase_marginalization:
        per_curve = self.phase_marginalized_likelihood(
            d_inner_h=d_cal, h_inner_h=h_inner_h)
    else:
        per_curve = np.real(d_cal - h_inner_h / 2)

    log_n_curves = np.log(self.number_of_response_curves)
    return logsumexp(per_curve) - log_n_curves

843 

def _setup_rho(self, d_inner_h, optimal_snr_squared):
    """Rescale the inner products from the current distance to ``_ref_dist``.

    <d|h> scales linearly with distance and <h|h> (real part only) scales
    quadratically.

    Returns
    =======
    tuple: (d_inner_h_ref, optimal_snr_squared_ref)
    """
    distance = self.parameters['luminosity_distance']
    d_inner_h_ref = d_inner_h * distance / self._ref_dist
    optimal_snr_squared_ref = (
        optimal_snr_squared.real * distance ** 2 / self._ref_dist ** 2.
    )
    return d_inner_h_ref, optimal_snr_squared_ref

851 

def log_likelihood(self):
    """Return the full log-likelihood (log-likelihood ratio plus noise log-likelihood)."""
    log_l_ratio = self.log_likelihood_ratio()
    log_l_noise = self.noise_log_likelihood()
    return log_l_ratio + log_l_noise

854 

@property
def _delta_distance(self):
    """Spacing between the first two points of the distance grid."""
    grid = self._distance_array
    return grid[1] - grid[0]

858 

@property
def _dist_multiplier(self):
    """Ratio of the reference distance to the first distance-grid point.

    This is the maximum of ref_dist / distance_array, assuming the grid is
    ascending.
    """
    closest_distance = self._distance_array[0]
    return self._ref_dist / closest_distance

863 

@property
def _optimal_snr_squared_ref_array(self):
    """Optimal-filter SNR-squared grid at the reference distance ``ref_dist``.

    Log-spaced from 1e-5 to 1e10, with one point per row of the lookup table.
    """
    n_rows = self._dist_margd_loglikelihood_array.shape[0]
    return np.logspace(-5, 10, n_rows)

868 

@property
def _d_inner_h_ref_array(self):
    """Matched-filter SNR grid at the reference distance ``ref_dist``.

    There is one point per column of the lookup table. With phase
    marginalisation only non-negative values are needed, so the grid is
    log-spaced from 1e-5 to 1e10. Otherwise half the points (rounded down)
    cover -1e3..-1e-3 and the rest cover 1e-3..1e10.
    """
    n_columns = self._dist_margd_loglikelihood_array.shape[1]
    if self.phase_marginalization:
        return np.logspace(-5, 10, n_columns)
    n_negative = n_columns // 2
    negative_part = -np.logspace(3, -3, n_negative)
    positive_part = np.logspace(-3, 10, n_columns - n_negative)
    return np.hstack((negative_part, positive_part))

880 

def _setup_distance_marginalization(self, lookup_table=None):
    """Prepare the lookup table and interpolant for distance marginalisation.

    Parameters
    ==========
    lookup_table: str, dict, optional
        Either the filename of a cached ``.npz`` lookup table (``None`` uses
        the default cache filename) or an already-loaded table given as a
        dict. If no usable table is available, or it was built with different
        settings, a new one is created with ``_create_lookup_table``.
    """
    if isinstance(lookup_table, str) or lookup_table is None:
        self.cached_lookup_table_filename = lookup_table
        lookup_table = self.load_lookup_table(
            self.cached_lookup_table_filename)

    use_cached_table = False
    if isinstance(lookup_table, dict):
        # _test_cached_lookup_table returns a (match, failed_key) tuple, which
        # is always truthy, so the match flag must be unpacked explicitly.
        match, failure = self._test_cached_lookup_table(lookup_table)
        if match and 'lookup_table' in lookup_table:
            use_cached_table = True
        else:
            logger.info(
                'Provided distance marginalisation lookup table does not '
                'match for {}; building a new one.'.format(
                    failure if failure is not None else 'lookup_table'))

    if use_cached_table:
        self._dist_margd_loglikelihood_array = lookup_table['lookup_table']
    else:
        self._create_lookup_table()
    # The spline is indexed by (<d|h>_ref, <h|h>_ref), hence the transpose of
    # the (h, d)-ordered table.
    self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline(
        self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array,
        self._dist_margd_loglikelihood_array.T, fill_value=-np.inf)

897 

@property
def cached_lookup_table_filename(self):
    """Filename used to cache the distance lookup table.

    Defaults to '.distance_marginalization_lookup.npz' and stores that
    default the first time it is requested.
    """
    filename = self._lookup_table_filename
    if filename is None:
        filename = '.distance_marginalization_lookup.npz'
        self._lookup_table_filename = filename
    return filename

904 

@cached_lookup_table_filename.setter
def cached_lookup_table_filename(self, filename):
    """Set the cache filename, appending '.npz' if it is missing.

    Non-string values (e.g. ``None``) leave the current filename unchanged.
    """
    if not isinstance(filename, str):
        return
    if not filename.endswith('.npz'):
        filename = filename + '.npz'
    self._lookup_table_filename = filename

911 

def load_lookup_table(self, filename):
    """Load a cached distance-marginalisation lookup table.

    Parameters
    ==========
    filename: str
        Path to an ``.npz`` file as written by ``cache_lookup_table``.

    Returns
    =======
    dict or None:
        The file contents if the file exists, can be read, and was built with
        the current settings (checked by ``_test_cached_lookup_table``).
        Otherwise ``None``; the caller is then responsible for building a
        new table.
    """
    import zipfile

    if os.path.exists(filename):
        try:
            # Use the NpzFile as a context manager so the underlying file
            # handle is closed once the arrays have been read into memory.
            with np.load(filename) as cached:
                loaded_file = dict(cached)
        except (AttributeError, OSError, TypeError, ValueError,
                zipfile.BadZipFile) as e:
            # An unreadable cache is treated like a missing one. The table is
            # not rebuilt here because the caller already rebuilds it when
            # None is returned.
            logger.warning(e)
            return None
        match, failure = self._test_cached_lookup_table(loaded_file)
        if match:
            logger.info('Loaded distance marginalisation lookup table from '
                        '{}.'.format(filename))
            return loaded_file
        else:
            logger.info('Loaded distance marginalisation lookup table does '
                        'not match for {}.'.format(failure))
    elif isinstance(filename, str):
        logger.info('Distance marginalisation file {} does not '
                    'exist'.format(filename))
    return None

932 

def cache_lookup_table(self):
    """Save the lookup table and the settings used to build it to ``cached_lookup_table_filename``."""
    target = self.cached_lookup_table_filename
    contents = dict(
        distance_array=self._distance_array,
        prior_array=self.distance_prior_array,
        lookup_table=self._dist_margd_loglikelihood_array,
        reference_distance=self._ref_dist,
        phase_marginalization=self.phase_marginalization,
    )
    np.savez(target, **contents)

940 

def _test_cached_lookup_table(self, loaded_file):
    """Check whether a loaded lookup table was built with the current settings.

    The distance grid, distance prior, reference distance and
    phase-marginalisation flag are compared with the values stored in
    ``loaded_file``. Arrays must have the same shape and agree under
    ``np.allclose`` with ``rtol=1e-15`` (numpy's default ``atol``).

    Parameters
    ==========
    loaded_file: dict
        Contents of a cached lookup table.

    Returns
    =======
    tuple: ``(True, None)`` on a match, otherwise ``(False, key)`` where
    ``key`` is the first entry that is missing or differs.
    """
    expected = dict(
        distance_array=self._distance_array,
        prior_array=self.distance_prior_array,
        reference_distance=self._ref_dist,
        phase_marginalization=self.phase_marginalization)
    for key, current_value in expected.items():
        if key not in loaded_file:
            return False, key
        cached = np.atleast_1d(loaded_file[key])
        current = np.atleast_1d(current_value)
        # np.allclose raises for non-broadcastable shapes and silently
        # broadcasts compatible ones, so require identical shapes first.
        if cached.shape != current.shape:
            return False, key
        if not np.allclose(cached, current, rtol=1e-15):
            return False, key
    return True, None

955 

def _create_lookup_table(self):
    """Build the distance-marginalised likelihood lookup table and cache it.

    Entry ``[i, j]`` is the log of the prior-weighted sum over the distance
    grid of ``exp(<d|h>(D) - <h|h>(D) / 2)``. Here the reference-distance
    values ``self._optimal_snr_squared_ref_array[i]`` and
    ``self._d_inner_h_ref_array[j]`` are rescaled by ``(ref_dist / D)**2``
    and ``ref_dist / D`` respectively. With phase marginalisation, ``<d|h>``
    is replaced by ``ln I0(|<d|h>|)``.

    The table is normalised by the total prior weight and then written to
    disk with ``cache_lookup_table``.
    """
    from tqdm.auto import tqdm
    logger.info('Building lookup table for distance marginalisation.')

    # The reference-grid properties take their sizes from this array's shape,
    # so it must be allocated before they are evaluated below.
    self._dist_margd_loglikelihood_array = np.zeros((400, 800))
    scaling = self._ref_dist / self._distance_array
    d_inner_h_array_full = np.outer(self._d_inner_h_ref_array, scaling)
    h_inner_h_array_full = np.outer(self._optimal_snr_squared_ref_array, scaling ** 2)
    if self.phase_marginalization:
        d_inner_h_array_full = ln_i0(abs(d_inner_h_array_full))
    prior_term = self.distance_prior_array * self._delta_distance
    # Each row is computed with one vectorised logsumexp over the distance
    # axis instead of one Python-level call per (i, j) entry. Temporary
    # memory per row scales as len(_d_inner_h_ref_array) * len(_distance_array).
    for ii, optimal_snr_squared_array in enumerate(tqdm(h_inner_h_array_full)):
        self._dist_margd_loglikelihood_array[ii] = logsumexp(
            d_inner_h_array_full - optimal_snr_squared_array / 2,
            b=prior_term,
            axis=-1,
        )
    # Normalise by the total prior weight, i.e. log(sum(prior_term)).
    log_norm = logsumexp(np.zeros_like(prior_term), b=prior_term)
    self._dist_margd_loglikelihood_array -= log_norm
    self.cache_lookup_table()

981 

def _setup_phase_marginalization(self, min_bound=-5, max_bound=10):
    """Deprecated method that only emits a deprecation warning.

    ``min_bound`` and ``max_bound`` are accepted for backward compatibility
    but are not used.
    """
    deprecation_message = (
        "The _setup_phase_marginalization method is deprecated and will be removed, "
        "please update the implementation of phase marginalization "
        "to use bilby.gw.utils.ln_i0"
    )
    logger.warning(deprecation_message)

988 

def _setup_time_marginalization(self):
    """Set up the coalescence-time grid used for time marginalisation.

    The grid spacing is two samples (``2 / sampling_frequency``). The grid
    spans the data segment and excludes its first point. The prior weight of
    each grid point is stored in ``self.time_prior_array``.
    """
    sampling_frequency = self.waveform_generator.sampling_frequency
    self._delta_tc = 2 / sampling_frequency
    duration = self.interferometers.duration
    n_points = int(duration / 2 * sampling_frequency + 1)
    offsets = np.linspace(0, duration, n_points)[1:]
    self._times = self.interferometers.start_time + offsets
    self.time_prior_array = (
        self.priors['geocent_time'].prob(self._times) * self._delta_tc
    )

998 

def _setup_calibration_marginalization(self, calibration_lookup_table, priors=None):
    """Build the calibration response-curve draws used for marginalisation.

    Parameters
    ==========
    calibration_lookup_table: dict or None
        Passed to ``calibration.build_calibration_lookup`` as ``lookup_files``.
    priors: dict, optional
        Prior dictionary. Any key that also appears in the drawn calibration
        parameters is replaced in place by ``DeltaFunction(0.0)``
        (presumably so those parameters are not also sampled).

    Sets ``calibration_draws``, ``calibration_parameter_draws`` and
    ``calibration_abs_draws`` (the squared modulus of each draw).
    """
    self.calibration_draws, self.calibration_parameter_draws = calibration.build_calibration_lookup(
        interferometers=self.interferometers,
        lookup_files=calibration_lookup_table,
        priors=priors,
        number_of_response_curves=self.number_of_response_curves,
        starting_index=self.starting_index,
    )
    # priors defaults to None; without a prior dictionary there is nothing to
    # pin (previously this raised AttributeError on priors.keys()).
    if priors is not None:
        for parameters in self.calibration_parameter_draws.values():
            if parameters is None:
                continue
            for key in set(parameters.keys()).intersection(priors.keys()):
                priors[key] = DeltaFunction(0.0)
    self.calibration_abs_draws = {
        name: np.abs(self.calibration_draws[name]) ** 2
        for name in self.calibration_draws
    }

1014 

@property
def interferometers(self):
    """The detectors used in the likelihood (stored as an InterferometerList by the setter)."""
    ifo_list = self._interferometers
    return ifo_list

1018 

@interferometers.setter
def interferometers(self, interferometers):
    """Store ``interferometers`` after converting it to an InterferometerList."""
    converted = InterferometerList(interferometers)
    self._interferometers = converted

1022 

def _rescale_signal(self, signal, new_distance):
    """Scale every entry of ``signal`` in place by ``ref_dist / new_distance``."""
    for polarization in signal:
        signal[polarization] *= self._ref_dist / new_distance

1026 

@property
def reference_frame(self):
    """The sky reference frame: "sky" or an InterferometerList of detectors."""
    frame = self._reference_frame
    return frame

1030 

@property
def _reference_frame_str(self):
    """String form of the reference frame, e.g. "sky" or "H1L1"."""
    frame = self.reference_frame
    if not isinstance(frame, str):
        return "".join(ifo.name for ifo in frame)
    return frame

1037 

@reference_frame.setter
def reference_frame(self, frame):
    """Set the reference frame.

    Accepts "sky"; an InterferometerList or list (only the first two
    detectors are kept); or a string whose first four characters name two
    detectors, e.g. "H1L1".

    Raises
    ======
    ValueError: if ``frame`` is none of the above.
    """
    # InterferometerList must be tested before list (the original checks it first).
    if frame == "sky":
        parsed = frame
    elif isinstance(frame, InterferometerList):
        parsed = frame[:2]
    elif isinstance(frame, list):
        parsed = InterferometerList(frame[:2])
    elif isinstance(frame, str):
        parsed = InterferometerList([frame[:2], frame[2:4]])
    else:
        raise ValueError("Unable to parse reference frame {}".format(frame))
    self._reference_frame = parsed

1050 

def get_sky_frame_parameters(self, parameters=None):
    """
    Generate ra, dec, and geocenter time for :code:`parameters`

    This method will attempt to convert from the reference time and sky
    parameters, but if they are not present it will fall back to ra and dec.
    If the reference time (``{time_reference}_time``) is missing but
    ``geocent_time`` is present, ``geocent_time`` is used in its place.

    Parameters
    ==========
    parameters: dict, optional
        The parameters to be converted.
        If not specified :code:`self.parameters` will be used.

    Returns
    =======
    dict: dictionary containing ra, dec, and geocent_time
    """
    if parameters is None:
        parameters = self.parameters
    time = parameters.get(f'{self.time_reference}_time', None)
    reference_time_found = time is not None
    if time is None and "geocent_time" in parameters:
        logger.warning(
            f"Cannot find {self.time_reference}_time in parameters. "
            "Falling back to geocent time"
        )
        # Perform the fallback the warning announces. Previously time stayed
        # None and was passed on to the zenith/azimuth conversion.
        time = parameters["geocent_time"]
    if not self.reference_frame == "sky":
        try:
            ra, dec = zenith_azimuth_to_ra_dec(
                parameters['zenith'], parameters['azimuth'],
                time, self.reference_frame)
        except KeyError:
            if "ra" in parameters and "dec" in parameters:
                ra = parameters["ra"]
                dec = parameters["dec"]
                logger.warning(
                    "Cannot convert from zenith/azimuth to ra/dec falling "
                    "back to provided ra/dec"
                )
            else:
                raise
    else:
        ra = parameters["ra"]
        dec = parameters["dec"]
    if "geocent" not in self.time_reference and reference_time_found:
        # Convert the detector reference time to geocentre time.
        geocent_time = time - self.reference_ifo.time_delay_from_geocenter(
            ra=ra, dec=dec, time=time
        )
    else:
        geocent_time = parameters["geocent_time"]
    return dict(ra=ra, dec=dec, geocent_time=geocent_time)

1101 

@property
def lal_version(self):
    """Description of the installed lal version, or "N/A" if unavailable."""
    try:
        from lal import git_version, __version__
        version_str = str(__version__)
        logger.info("Using lal version {}".format(version_str))
        git_str = str(git_version.verbose_msg).replace("\n", ";")
        logger.info("Using lal git version {}".format(git_str))
    except (ImportError, AttributeError):
        return "N/A"
    return "lal_version={}, lal_git_version={}".format(version_str, git_str)

1113 

@property
def lalsimulation_version(self):
    """Description of the installed lalsimulation version, or "N/A" if unavailable."""
    try:
        from lalsimulation import git_version, __version__
        version_str = str(__version__)
        logger.info("Using lalsimulation version {}".format(version_str))
        git_str = str(git_version.verbose_msg).replace("\n", ";")
        logger.info("Using lalsimulation git version {}".format(git_str))
    except (ImportError, AttributeError):
        return "N/A"
    return "lalsimulation_version={}, lalsimulation_git_version={}".format(version_str, git_str)

1125 

@property
def meta_data(self):
    """Summary of the likelihood configuration, waveform generator and lal versions."""
    interferometer_meta = self.interferometers.meta_data
    flags = {
        "time_marginalization": self.time_marginalization,
        "phase_marginalization": self.phase_marginalization,
        "distance_marginalization": self.distance_marginalization,
        "calibration_marginalization": self.calibration_marginalization,
    }
    wfg = self.waveform_generator
    generator_info = {
        "waveform_generator_class": wfg.__class__,
        "waveform_arguments": wfg.waveform_arguments,
        "frequency_domain_source_model": wfg.frequency_domain_source_model,
        "time_domain_source_model": wfg.time_domain_source_model,
        "parameter_conversion": wfg.parameter_conversion,
        "sampling_frequency": wfg.sampling_frequency,
        "duration": wfg.duration,
        "start_time": wfg.start_time,
    }
    return {
        "interferometers": interferometer_meta,
        **flags,
        **generator_info,
        "time_reference": self.time_reference,
        "reference_frame": self._reference_frame_str,
        "lal_version": self.lal_version,
        "lalsimulation_version": self.lalsimulation_version,
    }