Coverage for bilby/gw/likelihood/roq.py: 92%

610 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1 

2import json 

3 

4import numpy as np 

5 

6from .base import GravitationalWaveTransient 

7from ...core.utils import BilbyJsonEncoder, decode_bilby_json 

8from ...core.utils import ( 

9 logger, create_frequency_series, speed_of_light, radius_of_earth 

10) 

11from ..prior import CBCPriorDict 

12from ..utils import ln_i0 

13 

14 

15class ROQGravitationalWaveTransient(GravitationalWaveTransient): 

16 """A reduced order quadrature likelihood object 

17 

18 This uses the method described in Smith et al., (2016) Phys. Rev. D 94, 

19 044031. A public repository of the ROQ data is available from 

20 https://git.ligo.org/lscsoft/ROQ_data. 

21 

22 Parameters 

23 ========== 

24 interferometers: list, bilby.gw.detector.InterferometerList 

25 A list of `bilby.detector.Interferometer` instances - contains the 

26 detector data and power spectral densities 

27 waveform_generator: `bilby.waveform_generator.WaveformGenerator` 

28 An object which computes the frequency-domain strain of the signal, 

29 given some set of parameters 

30 linear_matrix: str, array_like 

31 Either a string pointing to the file from which to load the linear_matrix 

32 array, or the array itself. 

33 quadratic_matrix: str, array_like 

34 Either a string pointing to the file from which to load the 

35 quadratic_matrix array, or the array itself. 

36 roq_params: str, array_like 

37 Parameters describing the domain of validity of the ROQ basis. 

38 roq_params_check: bool 

39 If true, run tests using the roq_params to check the prior and data are 

40 valid for the ROQ 

41 roq_scale_factor: float 

42 The ROQ scale factor used. 

43 parameter_conversion: func, optional 

44 Function to update self.parameters before bases are selected based on 

45 the values of self.parameters. This enables a user to switch bases 

46 based on the values of parameters which are not directly used for 

47 sampling. 

48 priors: dict, bilby.prior.PriorDict 

49 A dictionary of priors containing at least the geocent_time prior 

50 Warning: when using marginalisation the dict is overwritten which will change

51 the dict you are passing in. If this behaviour is undesired, pass `priors.copy()`. 

52 time_marginalization: bool, optional 

53 If true, marginalize over time in the likelihood. 

54 The spacing of time samples can be specified through delta_tc. 

55 If using time marginalisation and jitter_time is True a "jitter" 

56 parameter is added to the prior which modifies the position of the 

57 grid of times. 

58 jitter_time: bool, optional 

59 Whether to introduce a `time_jitter` parameter. This avoids either 

60 missing the likelihood peak, or introducing biases in the 

61 reconstructed time posterior due to an insufficient sampling frequency. 

62 Default is True; using this parameter is strongly encouraged. 

63 delta_tc: float, optional 

64 The spacing of time samples for time marginalization. If not specified, 

65 it is determined based on the signal-to-noise ratio of signal. 

66 distance_marginalization_lookup_table: (dict, str), optional 

67 If a dict, dictionary containing the lookup_table, distance_array, 

68 (distance) prior_array, and reference_distance used to construct 

69 the table. 

70 If a string the name of a file containing these quantities. 

71 The lookup table is stored after construction in either the 

72 provided string or a default location: 

73 '.distance_marginalization_lookup_dmin{}_dmax{}_n{}.npz' 

74 reference_frame: (str, bilby.gw.detector.InterferometerList, list), optional 

75 Definition of the reference frame for the sky location. 

76 - "sky": sample in RA/dec, this is the default 

77 - e.g., "H1L1", ["H1", "L1"], InterferometerList(["H1", "L1"]): 

78 sample in azimuth and zenith, `azimuth` and `zenith` defined in the 

79 frame where the z-axis is aligned with the vector connecting H1 

80 and L1. 

81 time_reference: str, optional 

82 Name of the reference for the sampled time parameter. 

83 - "geocent"/"geocenter": sample in the time at the Earth's center, 

84 this is the default 

85 - e.g., "H1": sample in the time of arrival at H1 

86 

87 """ 

88 def __init__( 

89 self, interferometers, waveform_generator, priors, 

90 weights=None, linear_matrix=None, quadratic_matrix=None, 

91 roq_params=None, roq_params_check=True, roq_scale_factor=1, 

92 distance_marginalization=False, phase_marginalization=False, 

93 time_marginalization=False, jitter_time=True, delta_tc=None, 

94 distance_marginalization_lookup_table=None, 

95 reference_frame="sky", time_reference="geocenter", 

96 parameter_conversion=None 

97 

98 ): 

99 self._delta_tc = delta_tc 

100 super(ROQGravitationalWaveTransient, self).__init__( 

101 interferometers=interferometers, 

102 waveform_generator=waveform_generator, priors=priors, 

103 distance_marginalization=distance_marginalization, 

104 phase_marginalization=phase_marginalization, 

105 time_marginalization=time_marginalization, 

106 distance_marginalization_lookup_table=distance_marginalization_lookup_table, 

107 jitter_time=jitter_time, 

108 reference_frame=reference_frame, 

109 time_reference=time_reference 

110 ) 

111 

112 self.roq_params_check = roq_params_check 

113 self.roq_scale_factor = roq_scale_factor 

114 if isinstance(roq_params, np.ndarray) or roq_params is None: 

115 self.roq_params = roq_params 

116 elif isinstance(roq_params, str): 

117 self.roq_params_file = roq_params 

118 self.roq_params = np.genfromtxt(roq_params, names=True) 

119 else: 

120 raise TypeError("roq_params should be array or str") 

121 if isinstance(weights, dict): 

122 self.weights = weights 

123 elif isinstance(weights, str): 

124 self.weights = self.load_weights(weights) 

125 else: 

126 is_hdf5_linear = isinstance(linear_matrix, str) and linear_matrix.endswith('.hdf5') 

127 linear_matrix = self._parse_basis(linear_matrix, 'linear') 

128 is_hdf5_quadratic = isinstance(quadratic_matrix, str) and quadratic_matrix.endswith('.hdf5') 

129 quadratic_matrix = self._parse_basis(quadratic_matrix, 'quadratic') 

130 # retrieve roq params from a basis file if it is .hdf5 

131 if self.roq_params is None: 

132 if is_hdf5_linear: 

133 self.roq_params = np.array( 

134 [(linear_matrix['minimum_frequency_hz'][()], 

135 linear_matrix['maximum_frequency_hz'][()], 

136 linear_matrix['duration_s'][()])], 

137 dtype=[('flow', float), ('fhigh', float), ('seglen', float)] 

138 ) 

139 if is_hdf5_quadratic: 

140 if self.roq_params is None: 

141 self.roq_params = np.array( 

142 [(quadratic_matrix['minimum_frequency_hz'][()], 

143 quadratic_matrix['maximum_frequency_hz'][()], 

144 quadratic_matrix['duration_s'][()])], 

145 dtype=[('flow', float), ('fhigh', float), ('seglen', float)] 

146 ) 

147 else: 

148 self.roq_params['flow'] = max( 

149 self.roq_params['flow'], quadratic_matrix['minimum_frequency_hz'][()] 

150 ) 

151 self.roq_params['fhigh'] = min( 

152 self.roq_params['fhigh'], quadratic_matrix['maximum_frequency_hz'][()] 

153 ) 

154 self.roq_params['seglen'] = min( 

155 self.roq_params['seglen'], quadratic_matrix['duration_s'][()] 

156 ) 

157 if self.roq_params is not None: 

158 for ifo in self.interferometers: 

159 self.perform_roq_params_check(ifo) 

160 

161 self.weights = dict() 

162 self._set_weights(linear_matrix=linear_matrix, quadratic_matrix=quadratic_matrix) 

163 if is_hdf5_linear: 

164 linear_matrix.close() 

165 if is_hdf5_quadratic: 

166 quadratic_matrix.close() 

167 

168 self.number_of_bases_linear = len(self.weights[f'{self.interferometers[0].name}_linear']) 

169 self.number_of_bases_quadratic = len(self.weights[f'{self.interferometers[0].name}_quadratic']) 

170 self._cache = dict(parameters=None, basis_number_linear=None, basis_number_quadratic=None) 

171 self.parameter_conversion = parameter_conversion 

172 

173 for basis_type in ['linear', 'quadratic']: 

174 number_of_bases = getattr(self, f'number_of_bases_{basis_type}') 

175 if number_of_bases > 1: 

176 self._verify_numbers_of_prior_ranges_and_frequency_nodes(basis_type) 

177 else: 

178 self._check_frequency_nodes_exist_for_single_basis(basis_type) 

179 self._verify_prior_ranges(basis_type) 

180 

181 self._set_unique_frequency_nodes_and_inverse() 

182 # need to fill waveform_arguments here if single basis is used, as they will never be updated. 

183 if self.number_of_bases_linear == 1 and self.number_of_bases_quadratic == 1: 

184 frequency_nodes, linear_indices, quadratic_indices = \ 

185 self._unique_frequency_nodes_and_inverse[0][0] 

186 self._waveform_generator.waveform_arguments['frequency_nodes'] = frequency_nodes 

187 self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices 

188 self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices 

189 

190 def _verify_numbers_of_prior_ranges_and_frequency_nodes(self, basis_type): 

191 """ 

192 Check if self.weights contains lists of prior ranges and frequency nodes, and their sizes are equal to the 

193 number of bases. 

194 

195 Parameters 

196 ========== 

197 basis_type: str 

198 

199 """ 

200 number_of_bases = getattr(self, f'number_of_bases_{basis_type}') 

201 key = f'prior_range_{basis_type}' 

202 try: 

203 prior_ranges = self.weights[key] 

204 except KeyError: 

205 raise AttributeError( 

206 f'For the use of multiple {basis_type} ROQ bases, weights should contain "{key}".') 

207 else: 

208 for param_name in prior_ranges: 

209 if len(prior_ranges[param_name]) != number_of_bases: 

210 raise ValueError( 

211 f'The number of prior ranges for "{param_name}" does not ' 

212 f'match the number of {basis_type} bases') 

213 key = f'frequency_nodes_{basis_type}' 

214 try: 

215 frequency_nodes = self.weights[key] 

216 except KeyError: 

217 raise AttributeError( 

218 f'For the use of multiple {basis_type} ROQ bases, weights should contain "{key}".') 

219 else: 

220 if len(frequency_nodes) != number_of_bases: 

221 raise ValueError( 

222 f'The number of arrays of frequency nodes does not match the number of {basis_type} bases') 

223 

224 def _verify_prior_ranges(self, basis_type): 

225 """Check if the union of prior ranges is within the ROQ basis bounds. 

226 

227 Parameters 

228 ========== 

229 basis_type: str 

230 

231 """ 

232 key = f'prior_range_{basis_type}' 

233 if key not in self.weights: 

234 return 

235 prior_ranges = self.weights[key] 

236 for param_name, prior_ranges_of_this_param in prior_ranges.items(): 

237 prior_minimum = self.priors[param_name].minimum 

238 basis_minimum = np.min(prior_ranges_of_this_param[:, 0]) 

239 if prior_minimum < basis_minimum: 

240 raise BilbyROQParamsRangeError( 

241 f"Prior minimum of {param_name} {prior_minimum} less " 

242 f"than ROQ basis bound {basis_minimum}" 

243 ) 

244 

245 prior_maximum = self.priors[param_name].maximum 

246 basis_maximum = np.max(prior_ranges_of_this_param[:, 1]) 

247 if prior_maximum > basis_maximum: 

248 raise BilbyROQParamsRangeError( 

249 f"Prior maximum of {param_name} {prior_maximum} greater " 

250 f"than ROQ basis bound {basis_maximum}" 

251 ) 

252 

253 def _check_frequency_nodes_exist_for_single_basis(self, basis_type): 

254 """ 

255 For a single-basis case, frequency nodes should be contained in self._waveform_generator.waveform_arguments or 

256 self.weights. This method checks if it is the case and raise AttributeError if not. This method also adds 

257 frequency nodes to self._waveform_generator.waveform_arguments or self.weights from the other. 

258 

259 Parameters 

260 ========== 

261 basis_type: str 

262 

263 """ 

264 key = f'frequency_nodes_{basis_type}' 

265 if not (key in self.weights or key in self._waveform_generator.waveform_arguments): 

266 raise AttributeError(f'{key} should be contained in weights or waveform arguments.') 

267 elif key not in self._waveform_generator.waveform_arguments: 

268 self._waveform_generator.waveform_arguments[key] = self.weights[key][0] 

269 elif key not in self.weights: 

270 self.weights[key] = [self._waveform_generator.waveform_arguments[key]] 

271 

272 def _set_unique_frequency_nodes_and_inverse(self): 

273 """Set unique frequency nodes and indices to recover linear and quadratic frequency nodes for each combination 

274 of linear and quadratic bases 

275 """ 

276 self._unique_frequency_nodes_and_inverse = [] 

277 for idx_linear in range(self.number_of_bases_linear): 

278 tmp = [] 

279 frequency_nodes_linear = self.weights['frequency_nodes_linear'][idx_linear] 

280 size_linear = len(frequency_nodes_linear) 

281 for idx_quadratic in range(self.number_of_bases_quadratic): 

282 frequency_nodes_quadratic = self.weights['frequency_nodes_quadratic'][idx_quadratic] 

283 frequency_nodes_unique, original_indices = np.unique( 

284 np.hstack((frequency_nodes_linear, frequency_nodes_quadratic)), 

285 return_inverse=True 

286 ) 

287 linear_indices = original_indices[:size_linear] 

288 quadratic_indices = original_indices[size_linear:] 

289 tmp.append( 

290 (frequency_nodes_unique, linear_indices, quadratic_indices) 

291 ) 

292 self._unique_frequency_nodes_and_inverse.append(tmp) 

293 

294 def _setup_time_marginalization(self): 

295 if self._delta_tc is None: 

296 self._delta_tc = self._get_time_resolution() 

297 tcmin = self.priors['geocent_time'].minimum 

298 tcmax = self.priors['geocent_time'].maximum 

299 number_of_time_samples = int(np.ceil((tcmax - tcmin) / self._delta_tc)) 

300 # adjust delta tc so that the last time sample has an equal weight 

301 self._delta_tc = (tcmax - tcmin) / number_of_time_samples 

302 logger.info( 

303 "delta tc for time marginalization = {} seconds.".format(self._delta_tc)) 

304 self._times = tcmin + self._delta_tc / 2. + np.arange(number_of_time_samples) * self._delta_tc 

305 self._beam_pattern_reference_time = (tcmin + tcmax) / 2. 

306 

307 @staticmethod 

308 def _parse_basis(basis, basis_type): 

309 """ 

310 Parse basis and format it to an hdf5-like object 

311 

312 Parameters 

313 ---------- 

314 basis : array-like or str 

315 array-like basis or path to file 

316 basis_type : str 

317 'linear' or 'quadratic' 

318 

319 Returns 

320 ------- 

321 basis : hdf5-like object 

322 

323 """ 

324 if basis_type not in ['linear', 'quadratic']: 

325 raise ValueError(f'basis_type {basis_type} not recognized') 

326 if isinstance(basis, str): 

327 logger.info(f'Loading {basis_type}_matrix from {basis}') 

328 format = basis.split('.')[-1] 

329 if format == 'npy': 

330 basis = {f'basis_{basis_type}': {'0': {'basis': np.load(basis)}}} 

331 elif format == 'hdf5': 

332 import h5py 

333 basis = h5py.File(basis, 'r') 

334 else: 

335 raise IOError(f'Format {format} not recognized.') 

336 elif isinstance(basis, np.ndarray): 

337 basis = {f'basis_{basis_type}': {'0': {'basis': basis.T}}} 

338 else: 

339 raise TypeError('basis needs to be str or np.ndarray') 

340 return basis 

341 

342 def _select_prior_ranges(self, prior_ranges): 

343 """ 

344 Select prior ranges which have intersection with self.priors 

345 

346 Parameters 

347 ---------- 

348 prior_ranges : dict 

349 dictionary whose keys are parameter names and values are ndarray of 

350 their prior ranges 

351 

352 Returns 

353 ------- 

354 idxs_in_prior_range : ndarray 

355 indexes of selected prior ranges 

356 selected_prior_ranges : dict 

357 

358 """ 

359 param_names = list(prior_ranges.keys()) 

360 number_of_prior_ranges = len(prior_ranges[param_names[0]]) 

361 in_prior_range = np.ones(number_of_prior_ranges, dtype=bool) 

362 for param_name in param_names: 

363 try: 

364 prior = self.priors[param_name] 

365 except KeyError: 

366 continue 

367 prior_ranges_of_this_param = prior_ranges[param_name] 

368 in_prior_range *= \ 

369 (prior_ranges_of_this_param[:, 1] >= prior.minimum) * \ 

370 (prior_ranges_of_this_param[:, 0] <= prior.maximum) 

371 idxs_in_prior_range = np.arange(number_of_prior_ranges)[in_prior_range] 

372 return idxs_in_prior_range, \ 

373 dict((param_name, prior_ranges[param_name][idxs_in_prior_range]) 

374 for param_name in param_names) 

375 

376 def _update_basis(self): 

377 """ 

378 Update basis and frequency nodes depending on the curret values of parameters 

379 

380 This updates 

381 - self._cache 

382 - frequency_nodes_linear/quadratic in self._waveform_generator.waveform_arguments 

383 

384 """ 

385 parameters = self.parameters.copy() 

386 if self.parameter_conversion is not None: 

387 parameters = self.parameter_conversion(parameters) 

388 for basis_type, number_of_bases in zip( 

389 ['linear', 'quadratic'], [self.number_of_bases_linear, self.number_of_bases_quadratic] 

390 ): 

391 basis_number_key = f'basis_number_{basis_type}' 

392 if number_of_bases == 1: 

393 self._cache[basis_number_key] = 0 

394 continue 

395 in_prior_range = np.ones(number_of_bases, dtype=bool) 

396 prior_range_key = f'prior_range_{basis_type}' 

397 for param_name in self.weights[prior_range_key]: 

398 if param_name not in parameters: 

399 continue 

400 in_prior_range *= \ 

401 (self.weights[prior_range_key][param_name][:, 0] <= parameters[param_name]) * \ 

402 (self.weights[prior_range_key][param_name][:, 1] >= parameters[param_name]) 

403 self._cache[basis_number_key] = np.arange(number_of_bases)[in_prior_range][0] 

404 basis_number_linear = self._cache['basis_number_linear'] 

405 basis_number_quadratic = self._cache['basis_number_quadratic'] 

406 frequency_nodes, linear_indices, quadratic_indices = \ 

407 self._unique_frequency_nodes_and_inverse[basis_number_linear][basis_number_quadratic] 

408 self._waveform_generator.waveform_arguments['frequency_nodes'] = frequency_nodes 

409 self._waveform_generator.waveform_arguments['linear_indices'] = linear_indices 

410 self._waveform_generator.waveform_arguments['quadratic_indices'] = quadratic_indices 

411 self._cache['parameters'] = self.parameters.copy() 

412 

413 @property 

414 def basis_number_linear(self): 

415 if self.number_of_bases_linear > 1 or self.number_of_bases_quadratic > 1: 

416 if self.parameters != self._cache['parameters']: 

417 self._update_basis() 

418 return self._cache['basis_number_linear'] 

419 else: 

420 return 0 

421 

422 @property 

423 def basis_number_quadratic(self): 

424 if self.number_of_bases_linear > 1 or self.number_of_bases_quadratic > 1: 

425 if self.parameters != self._cache['parameters']: 

426 self._update_basis() 

427 return self._cache['basis_number_quadratic'] 

428 else: 

429 return 0 

430 

431 @property 

432 def waveform_generator(self): 

433 if getattr(self, 'number_of_bases_linear', 1) > 1 or getattr(self, 'number_of_bases_quadratic', 1) > 1: 

434 if self.parameters != self._cache['parameters']: 

435 self._update_basis() 

436 return self._waveform_generator 

437 

438 @waveform_generator.setter 

439 def waveform_generator(self, waveform_generator): 

440 self._waveform_generator = waveform_generator 

441 

442 def calculate_snrs(self, waveform_polarizations, interferometer, return_array=True): 

443 """ 

444 Compute the snrs for ROQ 

445 

446 Parameters 

447 ========== 

448 waveform_polarizations: waveform 

449 interferometer: bilby.gw.detector.Interferometer 

450 

451 """ 

452 if self.time_marginalization: 

453 time_ref = self._beam_pattern_reference_time 

454 else: 

455 time_ref = self.parameters['geocent_time'] 

456 

457 frequency_nodes = self.waveform_generator.waveform_arguments['frequency_nodes'] 

458 linear_indices = self.waveform_generator.waveform_arguments['linear_indices'] 

459 quadratic_indices = self.waveform_generator.waveform_arguments['quadratic_indices'] 

460 size_linear = len(linear_indices) 

461 size_quadratic = len(quadratic_indices) 

462 h_linear = np.zeros(size_linear, dtype=complex) 

463 h_quadratic = np.zeros(size_quadratic, dtype=complex) 

464 for mode in waveform_polarizations['linear']: 

465 response = interferometer.antenna_response( 

466 self.parameters['ra'], self.parameters['dec'], 

467 time_ref, 

468 self.parameters['psi'], 

469 mode 

470 ) 

471 h_linear += waveform_polarizations['linear'][mode] * response 

472 h_quadratic += waveform_polarizations['quadratic'][mode] * response 

473 

474 calib_factor = interferometer.calibration_model.get_calibration_factor( 

475 frequency_nodes, prefix='recalib_{}_'.format(interferometer.name), **self.parameters) 

476 h_linear *= calib_factor[linear_indices] 

477 h_quadratic *= calib_factor[quadratic_indices] 

478 

479 optimal_snr_squared = np.vdot( 

480 np.abs(h_quadratic)**2, 

481 self.weights[interferometer.name + '_quadratic'][self.basis_number_quadratic] 

482 ) 

483 

484 dt = interferometer.time_delay_from_geocenter( 

485 self.parameters['ra'], self.parameters['dec'], time_ref) 

486 dt_geocent = self.parameters['geocent_time'] - interferometer.strain_data.start_time 

487 ifo_time = dt_geocent + dt 

488 

489 indices, in_bounds = self._closest_time_indices( 

490 ifo_time, self.weights['time_samples']) 

491 if not in_bounds: 

492 logger.debug("SNR calculation error: requested time at edge of ROQ time samples") 

493 d_inner_h = -np.inf 

494 complex_matched_filter_snr = -np.inf 

495 else: 

496 d_inner_h_tc_array = np.einsum( 

497 'i,ji->j', np.conjugate(h_linear), 

498 self.weights[interferometer.name + '_linear'][self.basis_number_linear][indices]) 

499 

500 d_inner_h = self._interp_five_samples( 

501 self.weights['time_samples'][indices], d_inner_h_tc_array, ifo_time) 

502 

503 with np.errstate(invalid="ignore"): 

504 complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) 

505 

506 if return_array and self.time_marginalization: 

507 ifo_times = self._times - interferometer.strain_data.start_time 

508 ifo_times += dt 

509 if self.jitter_time: 

510 ifo_times += self.parameters['time_jitter'] 

511 d_inner_h_array = self._calculate_d_inner_h_array(ifo_times, h_linear, interferometer.name) 

512 else: 

513 d_inner_h_array = None 

514 

515 return self._CalculatedSNRs( 

516 d_inner_h=d_inner_h, 

517 optimal_snr_squared=optimal_snr_squared.real, 

518 complex_matched_filter_snr=complex_matched_filter_snr, 

519 d_inner_h_array=d_inner_h_array, 

520 ) 

521 

522 @staticmethod 

523 def _closest_time_indices(time, samples): 

524 """ 

525 Get the closest five times 

526 

527 Parameters 

528 ========== 

529 time: float 

530 Time to check 

531 samples: array-like 

532 Available times 

533 

534 Returns 

535 ======= 

536 indices: list 

537 Indices nearest to time 

538 in_bounds: bool 

539 Whether the indices are for valid times 

540 """ 

541 closest = int((time - samples[0]) / (samples[1] - samples[0])) 

542 indices = [closest + ii for ii in [-2, -1, 0, 1, 2]] 

543 in_bounds = (indices[0] >= 0) & (indices[-1] < samples.size) 

544 return indices, in_bounds 

545 

546 @staticmethod 

547 def _interp_five_samples(time_samples, values, time): 

548 """ 

549 Interpolate a function of time with its values at the closest five times. 

550 The algorithm is explained in https://dcc.ligo.org/T2100224. 

551 

552 Parameters 

553 ========== 

554 time_samples: array-like 

555 Closest 5 times 

556 values: array-like 

557 The values of the function at closest 5 times 

558 time: float 

559 Time at which the function is calculated 

560 

561 Returns 

562 ======= 

563 value: float 

564 The value of the function at the input time 

565 """ 

566 r1 = (-values[0] + 8. * values[1] - 14. * values[2] + 8. * values[3] - values[4]) / 4. 

567 r2 = values[2] - 2. * values[3] + values[4] 

568 a = (time_samples[3] - time) / (time_samples[1] - time_samples[0]) 

569 b = 1. - a 

570 c = (a**3. - a) / 6. 

571 d = (b**3. - b) / 6. 

572 return a * values[2] + b * values[3] + c * r1 + d * r2 

573 

574 def _calculate_d_inner_h_array(self, times, h_linear, ifo_name): 

575 """ 

576 Calculate d_inner_h at regularly-spaced time samples. Each value is 

577 interpolated from the nearest 5 samples with the algorithm explained in 

578 https://dcc.ligo.org/T2100224. 

579 

580 Parameters 

581 ========== 

582 times: array-like 

583 Regularly-spaced time samples at which d_inner_h are calculated. 

584 h_linear: array-like 

585 Waveforms at linear frequency nodes 

586 ifo_name: str 

587 

588 Returns 

589 ======= 

590 d_inner_h_array: array-like 

591 """ 

592 roq_time_space = self.weights['time_samples'][1] - self.weights['time_samples'][0] 

593 times_per_roq_time_space = (times - self.weights['time_samples'][0]) / roq_time_space 

594 closest_idxs = np.floor(times_per_roq_time_space).astype(int) 

595 # Get the nearest 5 samples of d_inner_h. Calculate only the required d_inner_h values if the time 

596 # spacing is larger than 5 times the ROQ time spacing. 

597 weights_linear = self.weights[ifo_name + '_linear'][self.basis_number_linear] 

598 h_linear_conj = np.conjugate(h_linear) 

599 if (times[1] - times[0]) / roq_time_space > 5: 

600 d_inner_h_m2 = np.dot(weights_linear[closest_idxs - 2], h_linear_conj) 

601 d_inner_h_m1 = np.dot(weights_linear[closest_idxs - 1], h_linear_conj) 

602 d_inner_h_0 = np.dot(weights_linear[closest_idxs], h_linear_conj) 

603 d_inner_h_p1 = np.dot(weights_linear[closest_idxs + 1], h_linear_conj) 

604 d_inner_h_p2 = np.dot(weights_linear[closest_idxs + 2], h_linear_conj) 

605 else: 

606 d_inner_h_at_roq_time_samples = np.dot(weights_linear, h_linear_conj) 

607 d_inner_h_m2 = d_inner_h_at_roq_time_samples[closest_idxs - 2] 

608 d_inner_h_m1 = d_inner_h_at_roq_time_samples[closest_idxs - 1] 

609 d_inner_h_0 = d_inner_h_at_roq_time_samples[closest_idxs] 

610 d_inner_h_p1 = d_inner_h_at_roq_time_samples[closest_idxs + 1] 

611 d_inner_h_p2 = d_inner_h_at_roq_time_samples[closest_idxs + 2] 

612 # quantities required for spline interpolation 

613 b = times_per_roq_time_space - closest_idxs 

614 a = 1. - b 

615 c = (a**3. - a) / 6. 

616 d = (b**3. - b) / 6. 

617 r1 = (-d_inner_h_m2 + 8. * d_inner_h_m1 - 14. * d_inner_h_0 + 8. * d_inner_h_p1 - d_inner_h_p2) / 4. 

618 r2 = d_inner_h_0 - 2. * d_inner_h_p1 + d_inner_h_p2 

619 return a * d_inner_h_0 + b * d_inner_h_p1 + c * r1 + d * r2 

620 

621 def perform_roq_params_check(self, ifo=None): 

622 """ Perform checking that the prior and data are valid for the ROQ 

623 

624 Parameters 

625 ========== 

626 ifo: bilby.gw.detector.Interferometer 

627 The interferometer 

628 """ 

629 if self.roq_params_check is False: 

630 logger.warning("No ROQ params checking performed") 

631 return 

632 else: 

633 if getattr(self, "roq_params_file", None) is not None: 

634 msg = ("Check ROQ params {} with roq_scale_factor={}" 

635 .format(self.roq_params_file, self.roq_scale_factor)) 

636 else: 

637 msg = ("Check ROQ params with roq_scale_factor={}" 

638 .format(self.roq_scale_factor)) 

639 logger.info(msg) 

640 

641 roq_params = self.roq_params 

642 roq_minimum_frequency = roq_params['flow'] * self.roq_scale_factor 

643 roq_maximum_frequency = roq_params['fhigh'] * self.roq_scale_factor 

644 roq_segment_length = roq_params['seglen'] / self.roq_scale_factor 

645 try: 

646 roq_minimum_chirp_mass = roq_params['chirpmassmin'] / self.roq_scale_factor 

647 except ValueError: 

648 roq_minimum_chirp_mass = None 

649 try: 

650 roq_maximum_chirp_mass = roq_params['chirpmassmax'] / self.roq_scale_factor 

651 except ValueError: 

652 roq_maximum_chirp_mass = None 

653 try: 

654 roq_minimum_component_mass = roq_params['compmin'] / self.roq_scale_factor 

655 except ValueError: 

656 roq_minimum_component_mass = None 

657 

658 if ifo.maximum_frequency > roq_maximum_frequency: 

659 raise BilbyROQParamsRangeError( 

660 "Requested maximum frequency {} larger than ROQ basis fhigh {}" 

661 .format(ifo.maximum_frequency, roq_maximum_frequency) 

662 ) 

663 if ifo.minimum_frequency < roq_minimum_frequency: 

664 raise BilbyROQParamsRangeError( 

665 "Requested minimum frequency {} lower than ROQ basis flow {}" 

666 .format(ifo.minimum_frequency, roq_minimum_frequency) 

667 ) 

668 if ifo.strain_data.duration != roq_segment_length: 

669 raise BilbyROQParamsRangeError( 

670 "Requested duration differs from ROQ basis seglen") 

671 

672 priors = self.priors 

673 if isinstance(priors, CBCPriorDict) is False: 

674 logger.warning("Unable to check ROQ parameter bounds: priors not understood") 

675 return 

676 

677 if roq_minimum_chirp_mass is not None: 

678 if priors.minimum_chirp_mass is None: 

679 logger.warning("Unable to check minimum chirp mass ROQ bounds") 

680 elif priors.minimum_chirp_mass < roq_minimum_chirp_mass: 

681 raise BilbyROQParamsRangeError( 

682 "Prior minimum chirp mass {} less than ROQ basis bound {}" 

683 .format(priors.minimum_chirp_mass, roq_minimum_chirp_mass) 

684 ) 

685 

686 if roq_maximum_chirp_mass is not None: 

687 if priors.maximum_chirp_mass is None: 

688 logger.warning("Unable to check maximum_chirp mass ROQ bounds") 

689 elif priors.maximum_chirp_mass > roq_maximum_chirp_mass: 

690 raise BilbyROQParamsRangeError( 

691 "Prior maximum chirp mass {} greater than ROQ basis bound {}" 

692 .format(priors.maximum_chirp_mass, roq_maximum_chirp_mass) 

693 ) 

694 

695 if roq_minimum_component_mass is not None: 

696 if priors.minimum_component_mass is None: 

697 logger.warning("Unable to check minimum component mass ROQ bounds") 

698 elif priors.minimum_component_mass < roq_minimum_component_mass: 

699 raise BilbyROQParamsRangeError( 

700 "Prior minimum component mass {} less than ROQ basis bound {}" 

701 .format(priors.minimum_component_mass, roq_minimum_component_mass) 

702 ) 

703 

704 def _set_weights(self, linear_matrix, quadratic_matrix): 

705 """ 

706 Setup the time-dependent ROQ weights. 

707 

708 Parameters 

709 ========== 

710 linear_matrix, quadratic_matrix: dictionary or h5py.File 

711 linear and quadratic basis 

712 

713 """ 

714 time_space = self._get_time_resolution() 

715 number_of_time_samples = int(self.interferometers.duration / time_space) 

716 earth_light_crossing_time = 2 * radius_of_earth / speed_of_light + 5 * time_space 

717 start_idx = max( 

718 0, 

719 int(np.floor(( 

720 self.priors['{}_time'.format(self.time_reference)].minimum 

721 - earth_light_crossing_time 

722 - self.interferometers.start_time 

723 ) / time_space)) 

724 ) 

725 end_idx = min( 

726 number_of_time_samples - 1, 

727 int(np.ceil(( 

728 self.priors['{}_time'.format(self.time_reference)].maximum 

729 + earth_light_crossing_time 

730 - self.interferometers.start_time 

731 ) / time_space)) 

732 ) 

733 self.weights['time_samples'] = np.arange(start_idx, end_idx + 1) * time_space 

734 logger.info("Using {} ROQ time samples".format(len(self.weights['time_samples']))) 

735 

736 # select bases to be used, set prior ranges and frequency nodes if exist 

737 idxs_in_prior_range = dict() 

738 for basis_type, matrix in zip(['linear', 'quadratic'], [linear_matrix, quadratic_matrix]): 

739 key = f'prior_range_{basis_type}' 

740 if key in matrix: 

741 prior_ranges = {} 

742 for param_name in matrix[key]: 

743 if 'roq_scale_power' in matrix[key][param_name].attrs: 

744 roq_scale_factor = self.roq_scale_factor**matrix[key][param_name].attrs['roq_scale_power'] 

745 else: 

746 roq_scale_factor = 1. 

747 prior_ranges[param_name] = matrix[key][param_name][()] * roq_scale_factor 

748 selected_idxs, selected_prior_ranges = self._select_prior_ranges(prior_ranges) 

749 if len(selected_idxs) == 0: 

750 raise BilbyROQParamsRangeError(f"There are no {basis_type} ROQ bases within the prior range.") 

751 self.weights[key] = selected_prior_ranges 

752 idxs_in_prior_range[basis_type] = selected_idxs 

753 else: 

754 idxs_in_prior_range[basis_type] = [0] 

755 if 'frequency_nodes' in matrix[f'basis_{basis_type}'][str(idxs_in_prior_range[basis_type][0])]: 

756 self.weights[f'frequency_nodes_{basis_type}'] = [ 

757 matrix[f'basis_{basis_type}'][str(i)]['frequency_nodes'][()] * self.roq_scale_factor 

758 for i in idxs_in_prior_range[basis_type]] 

759 

760 if 'multiband_linear' in linear_matrix: 

761 multiband_linear = linear_matrix['multiband_linear'][()] 

762 else: 

763 multiband_linear = False 

764 if 'multiband_quadratic' in quadratic_matrix: 

765 multiband_quadratic = quadratic_matrix['multiband_quadratic'][()] 

766 else: 

767 multiband_quadratic = False 

768 

769 # Get intersection between ifo and ROQ frequency samples. Required only for non-multibanded basis. 

770 if not (multiband_linear and multiband_quadratic): 

771 roq_idxs = {} 

772 ifo_idxs = {} 

773 for ifo in self.interferometers: 

774 if self.roq_params is not None: 

775 # Get scaled ROQ quantities 

776 roq_scaled_minimum_frequency = self.roq_params['flow'] * self.roq_scale_factor 

777 roq_scaled_maximum_frequency = self.roq_params['fhigh'] * self.roq_scale_factor 

778 roq_scaled_segment_length = self.roq_params['seglen'] / self.roq_scale_factor 

779 # Generate frequencies for the ROQ 

780 roq_frequencies = create_frequency_series( 

781 sampling_frequency=roq_scaled_maximum_frequency * 2, 

782 duration=roq_scaled_segment_length) 

783 roq_mask = roq_frequencies >= roq_scaled_minimum_frequency 

784 roq_frequencies = roq_frequencies[roq_mask] 

785 overlap_frequencies, ifo_idxs_this_ifo, roq_idxs_this_ifo = np.intersect1d( 

786 ifo.frequency_array[ifo.frequency_mask], roq_frequencies, 

787 return_indices=True) 

788 else: 

789 overlap_frequencies = ifo.frequency_array[ifo.frequency_mask] 

790 roq_idxs_this_ifo = np.arange( 

791 linear_matrix['basis_linear'][str(idxs_in_prior_range['linear'][0])]['basis'].shape[1], 

792 dtype=int) 

793 ifo_idxs_this_ifo = np.arange(sum(ifo.frequency_mask)) 

794 if len(ifo_idxs_this_ifo) != len(roq_idxs_this_ifo): 

795 raise ValueError( 

796 "Mismatch between ROQ basis and frequency array for " 

797 "{}".format(ifo.name)) 

798 logger.info( 

799 "Building ROQ weights for {} with {} frequencies between {} " 

800 "and {}.".format( 

801 ifo.name, len(overlap_frequencies), 

802 min(overlap_frequencies), max(overlap_frequencies))) 

803 roq_idxs[ifo.name] = roq_idxs_this_ifo 

804 ifo_idxs[ifo.name] = ifo_idxs_this_ifo 

805 

806 if multiband_linear: 

807 self._set_weights_linear_multiband(linear_matrix, idxs_in_prior_range['linear']) 

808 else: 

809 self._set_weights_linear(linear_matrix, idxs_in_prior_range['linear'], roq_idxs, ifo_idxs) 

810 

811 if multiband_quadratic: 

812 self._set_weights_quadratic_multiband(quadratic_matrix, idxs_in_prior_range['quadratic']) 

813 else: 

814 self._set_weights_quadratic(quadratic_matrix, idxs_in_prior_range['quadratic'], roq_idxs, ifo_idxs) 

815 

def _set_weights_linear(self, linear_matrix, basis_idxs, roq_idxs, ifo_idxs):
    """
    Setup the time-dependent linear ROQ weights. See https://dcc.ligo.org/LIGO-T2100125 for the detail of how to
    compute them.

    For each basis in ``basis_idxs`` and each interferometer, one complex array of shape
    ``(len(self.weights['time_samples']), basis_size)`` is appended to
    ``self.weights[ifo.name + '_linear']``. Column ``i`` holds the inverse FFT of
    ``data / PSD * conj(basis_element_i)``, evaluated at the ROQ time samples.

    Parameters
    ==========
    linear_matrix : dictionary or h5py.File
        linear basis
    basis_idxs : array-like
        indexes of bases used for a run
    roq_idxs : dictionary
        dictionary whose keys are interferometer names and values are indexes of basis components intersecting
        frequency-domain data
    ifo_idxs : dictionary
        dictionary whose keys are interferometer names and values are indexes of frequency-domain data intersecting
        basis components

    """
    for ifo in self.interferometers:
        self.weights[ifo.name + '_linear'] = []
    time_samples = self.weights['time_samples']
    time_space = time_samples[1] - time_samples[0]
    # These ratios are integers by construction. Round rather than truncate so
    # that a floating-point result one ulp below the integer cannot shift an
    # index (or the FFT length) by one.
    number_of_time_samples = int(np.round(self.interferometers.duration / time_space))
    start_idx = int(np.round(time_samples[0] / time_space))
    end_idx = int(np.round(time_samples[-1] / time_space))
    nonzero_idxs = {}
    data_over_psd = {}
    for ifo in self.interferometers:
        # Index of the first unmasked frequency bin (f_min * duration); adding it
        # maps indexes into the masked arrays onto indexes of the full FFT input.
        first_bin = int(np.round(
            ifo.frequency_array[ifo.frequency_mask][0] * self.interferometers.duration))
        nonzero_idxs[ifo.name] = ifo_idxs[ifo.name] + first_bin
        data_over_psd[ifo.name] = (
            ifo.frequency_domain_strain[ifo.frequency_mask][ifo_idxs[ifo.name]]
            / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]]
        )
    try:
        import pyfftw
        ifft_input = pyfftw.empty_aligned(number_of_time_samples, dtype=complex)
        ifft_output = pyfftw.empty_aligned(number_of_time_samples, dtype=complex)
        ifft = pyfftw.FFTW(ifft_input, ifft_output, direction='FFTW_BACKWARD')
    except ImportError:
        pyfftw = None
        logger.warning("You do not have pyfftw installed, falling back to numpy.fft.")
        ifft_input = np.zeros(number_of_time_samples, dtype=complex)
        ifft = np.fft.ifft
    for basis_idx in basis_idxs:
        logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.")
        linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis']
        basis_size = linear_matrix_single.shape[0]
        for ifo in self.interferometers:
            # Reset the shared FFT input buffer. Plain assignment (rather than
            # multiplying by zero) also clears non-finite values left behind by
            # a previous interferometer at bins this one does not overwrite.
            ifft_input[:] = 0.
            linear_weights = np.zeros((len(time_samples), basis_size), dtype=complex)
            for i in range(basis_size):
                basis_element = linear_matrix_single[i][roq_idxs[ifo.name]]
                ifft_input[nonzero_idxs[ifo.name]] = data_over_psd[ifo.name] * np.conj(basis_element)
                linear_weights[:, i] = ifft(ifft_input)[start_idx:end_idx + 1]
            # Undo the 1/N normalisation of the inverse FFT and apply the
            # 4 / duration prefactor of the noise-weighted inner product.
            linear_weights *= 4. * number_of_time_samples / self.interferometers.duration
            self.weights[ifo.name + '_linear'].append(linear_weights)
    if pyfftw is not None:
        pyfftw.forget_wisdom()

874 

def _set_weights_linear_multiband(self, linear_matrix, basis_idxs):
    """
    Setup the time-dependent linear ROQ weights from multibanded basis

    The basis frequency nodes are split into bands. Band ``b`` has duration
    ``Tb`` (``linear_matrix['durations_s_linear']`` divided by
    ``self.roq_scale_factor``). It covers the inclusive frequency-bin range
    ``linear_matrix['start_end_frequency_bins_linear'][b]`` at resolution
    ``1 / Tb``.

    For each selected basis and each interferometer, one array is appended to
    ``self.weights[ifo.name + '_linear']``. Its shape is (number of time
    samples, number of basis elements), assuming the stored basis has shape
    (n_elements, total number of frequency nodes). TODO confirm.

    Parameters
    ==========
    linear_matrix : dictionary or h5py.File
        linear basis
    basis_idxs : array-like
        indexes of bases used for a run

    """
    for ifo in self.interferometers:
        self.weights[ifo.name + '_linear'] = []
    Tbs = linear_matrix['durations_s_linear'][()] / self.roq_scale_factor
    start_end_frequency_bins = linear_matrix['start_end_frequency_bins_linear'][()]
    # total number of frequency nodes summed over all bands (bin ranges are inclusive)
    basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1)
    # highest frequency reached by any band
    fhigh_basis = np.max(start_end_frequency_bins[:, 1] / Tbs)
    # prepare time-shifted data, which is multiplied by basis
    tc_shifted_data = dict()
    for ifo in self.interferometers:
        # data / PSD on the full-duration frequency grid, zero-padded up to fhigh_basis.
        # NOTE(review): this raises IndexError if the frequency mask extends
        # above fhigh_basis; confirm callers guarantee it does not.
        over_whitened_frequency_data = np.zeros(int(fhigh_basis * ifo.duration) + 1, dtype=complex)
        over_whitened_frequency_data[np.arange(len(ifo.frequency_domain_strain))[ifo.frequency_mask]] = \
            ifo.frequency_domain_strain[ifo.frequency_mask] / ifo.power_spectral_density_array[ifo.frequency_mask]
        over_whitened_time_data = np.fft.irfft(over_whitened_frequency_data)
        tc_shifted_data[ifo.name] = np.zeros((basis_dimension, len(self.weights['time_samples'])), dtype=complex)
        # running offset of the current band within the concatenated node axis
        start_idx_of_band = 0
        for b, Tb in enumerate(Tbs):
            start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b]
            fs = np.arange(start_frequency_bin, end_frequency_bin + 1) / Tb
            # Spectrum of the final int(2 * fhigh_basis * Tb) time samples,
            # i.e. approximately the last Tb seconds, at resolution 1 / Tb.
            # Only this band's bins are kept.
            Db = np.fft.rfft(
                over_whitened_time_data[-int(2. * fhigh_basis * Tb):]
            )[start_frequency_bin:end_frequency_bin + 1]
            start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1
            # 4 / Tb inner-product prefactor. The phase factor re-references the
            # segment start (ifo.duration - Tb) to each ROQ time sample.
            tc_shifted_data[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * Db[:, None] * np.exp(
                2. * np.pi * 1j * fs[:, None] * (self.weights['time_samples'][None, :] - ifo.duration + Tb))
            start_idx_of_band = start_idx_of_next_band
    # compute inner products
    for basis_idx in basis_idxs:
        logger.info(f"Building linear ROQ weights for the {basis_idx}-th basis.")
        linear_matrix_single = linear_matrix['basis_linear'][str(basis_idx)]['basis'][()]
        for ifo in self.interferometers:
            # contract conj(basis) with the shifted data over frequency nodes; transpose to (time, element)
            self.weights[ifo.name + '_linear'].append(
                np.dot(np.conj(linear_matrix_single), tc_shifted_data[ifo.name]).T)

919 

def _set_weights_quadratic(self, quadratic_matrix, basis_idxs, roq_idxs, ifo_idxs):
    """
    Setup the quadratic ROQ weights

    For each basis in ``basis_idxs`` and each interferometer, appends
    ``4 / duration * basis . (1 / PSD)`` to ``self.weights[ifo.name + '_quadratic']``.
    Only the real part of the basis is used, and only the overlapping
    frequencies given by ``roq_idxs`` / ``ifo_idxs``.

    Parameters
    ==========
    quadratic_matrix : dictionary or h5py.File
        quadratic basis
    basis_idxs : array-like
        indexes of bases used for a run
    roq_idxs : dictionary
        dictionary whose keys are interferometer names and values are indexes of basis components intersecting
        frequency-domain data
    ifo_idxs : dictionary
        dictionary whose keys are interferometer names and values are indexes of frequency-domain data intersecting
        basis components

    """
    # The inverse PSD on the overlapping frequencies does not depend on the
    # basis, so compute it once per interferometer rather than once per basis.
    inverse_psds = dict()
    for ifo in self.interferometers:
        self.weights[ifo.name + '_quadratic'] = []
        inverse_psds[ifo.name] = \
            1 / ifo.power_spectral_density_array[ifo.frequency_mask][ifo_idxs[ifo.name]]
    for basis_idx in basis_idxs:
        logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.")
        quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real
        for ifo in self.interferometers:
            self.weights[ifo.name + '_quadratic'].append(
                4. / ifo.strain_data.duration * np.dot(
                    quadratic_matrix_single[:, roq_idxs[ifo.name]],
                    inverse_psds[ifo.name]))
        # release the (potentially large) dense basis before loading the next one
        del quadratic_matrix_single

949 

def _set_weights_quadratic_multiband(self, quadratic_matrix, basis_idxs):
    """
    Setup the quadratic ROQ weights from multibanded basis

    Band ``b`` has duration ``Tb`` (``quadratic_matrix['durations_s_quadratic']``
    divided by ``self.roq_scale_factor``). It covers the inclusive frequency-bin
    range ``quadratic_matrix['start_end_frequency_bins_quadratic'][b]`` at
    resolution ``1 / Tb``. The inverse PSD is resampled onto these band grids.
    For each selected basis and interferometer, ``basis . inverse_psd`` is
    appended to ``self.weights[ifo.name + '_quadratic']``.

    Parameters
    ==========
    quadratic_matrix : dictionary or h5py.File
        quadratic basis
    basis_idxs : array-like
        indexes of bases used for a run

    """
    for ifo in self.interferometers:
        self.weights[ifo.name + '_quadratic'] = []
    Tbs = quadratic_matrix['durations_s_quadratic'][()] / self.roq_scale_factor
    start_end_frequency_bins = quadratic_matrix['start_end_frequency_bins_quadratic'][()]
    # total number of frequency nodes summed over all bands (bin ranges are inclusive)
    basis_dimension = np.sum(start_end_frequency_bins[:, 1] - start_end_frequency_bins[:, 0] + 1)
    # highest frequency reached by any band
    fhigh_basis = np.max(start_end_frequency_bins[:, 1] / Tbs)
    # prepare coefficients multiplied by basis
    multibanded_inverse_psd = dict()
    for ifo in self.interferometers:
        # 1 / PSD on the full-duration grid, zero outside the frequency mask, then to the time domain
        inverse_psd_frequency = np.zeros(int(fhigh_basis * ifo.duration) + 1)
        inverse_psd_frequency[np.arange(len(ifo.power_spectral_density_array))[ifo.frequency_mask]] = \
            1. / ifo.power_spectral_density_array[ifo.frequency_mask]
        inverse_psd_time = np.fft.irfft(inverse_psd_frequency)
        multibanded_inverse_psd[ifo.name] = np.zeros(basis_dimension)
        start_idx_of_band = 0
        for b, Tb in enumerate(Tbs):
            start_frequency_bin, end_frequency_bin = start_end_frequency_bins[b]
            number_of_samples_half = int(fhigh_basis * Tb)
            start_idx_of_next_band = start_idx_of_band + end_frequency_bin - start_frequency_bin + 1
            # Keep the first and last number_of_samples_half time samples (the
            # lags on either side of zero) and transform back at resolution 1 / Tb.
            # NOTE(review): if number_of_samples_half were 0, the slice [-0:]
            # would select the whole array; confirm this cannot happen.
            multibanded_inverse_psd[ifo.name][start_idx_of_band:start_idx_of_next_band] = 4. / Tb * np.fft.rfft(
                np.append(inverse_psd_time[:number_of_samples_half], inverse_psd_time[-number_of_samples_half:])
            )[start_frequency_bin:end_frequency_bin + 1].real
            start_idx_of_band = start_idx_of_next_band
    # compute inner products
    for basis_idx in basis_idxs:
        logger.info(f"Building quadratic ROQ weights for the {basis_idx}-th basis.")
        quadratic_matrix_single = quadratic_matrix['basis_quadratic'][str(basis_idx)]['basis'][()].real
        for ifo in self.interferometers:
            self.weights[ifo.name + '_quadratic'].append(
                np.dot(quadratic_matrix_single, multibanded_inverse_psd[ifo.name]))

992 

def save_weights(self, filename, format='hdf5'):
    """
    Save ROQ weights into a single file. format should be npz, or hdf5.
    For weights from multiple bases, hdf5 is the only possible option.
    Support for json format is deprecated as of :code:`v2.1` and will be
    removed in :code:`v2.2`, another method should be used by default.

    Parameters
    ==========
    filename : str or path-like
        The name of the file to save the weights to. The extension
        ``"." + format`` is appended unless the name already ends with it.
    format : str
        The format to save the data to, this should be one of
        :code:`"hdf5"`, :code:`"npz"`, default=:code:`"hdf5"`.

    Raises
    ======
    IOError
        If the format is not recognized.
    ValueError
        If json or npz is requested while more than one linear or quadratic
        basis is in use.
    """
    if format not in ['json', 'npz', 'hdf5']:
        raise IOError(f"Format {format} not recognized.")
    if format == "json":
        import warnings

        warnings.warn(
            "json format for ROQ weights is deprecated, use hdf5 instead.",
            DeprecationWarning
        )
    filename = str(filename)
    # Test the suffix, not a substring. A name such as "hdf5_runs/weights"
    # contains "hdf5" but has no extension, so load_weights could not infer
    # its format.
    if not filename.endswith("." + format):
        filename += "." + format
    logger.info(f"Saving ROQ weights to {filename}")
    if format == 'json' or format == 'npz':
        if self.number_of_bases_linear > 1 or self.number_of_bases_quadratic > 1:
            raise ValueError(f'Format {format} not compatible with multiple bases')
        weights = dict()
        weights['time_samples'] = self.weights['time_samples']
        for basis_type in ['linear', 'quadratic']:
            for ifo in self.interferometers:
                key = f'{ifo.name}_{basis_type}'
                weights[key] = self.weights[key][0]
        if format == 'json':
            with open(filename, 'w') as file:
                json.dump(weights, file, indent=2, cls=BilbyJsonEncoder)
        else:
            np.savez(filename, **weights)
    else:
        import h5py
        with h5py.File(filename, 'w') as f:
            f.create_dataset('time_samples',
                             data=self.weights['time_samples'])
            for basis_type in ['linear', 'quadratic']:
                key = f'prior_range_{basis_type}'
                if key in self.weights:
                    grp = f.create_group(key)
                    for param_name in self.weights[key]:
                        grp.create_dataset(
                            param_name, data=self.weights[key][param_name])
                key = f'frequency_nodes_{basis_type}'
                if key in self.weights:
                    grp = f.create_group(key)
                    for i in range(len(self.weights[key])):
                        grp.create_dataset(
                            str(i), data=self.weights[key][i])
                for ifo in self.interferometers:
                    key = f"{ifo.name}_{basis_type}"
                    grp = f.create_group(key)
                    for i in range(len(self.weights[key])):
                        grp.create_dataset(
                            str(i), data=self.weights[key][i])

1058 

def load_weights(self, filename, format=None):
    """
    Load ROQ weights. format should be json, npz, or hdf5.
    json or npz file is assumed to contain weights from a single basis.
    Support for json format is deprecated as of :code:`v2.1` and will be
    removed in :code:`v2.2`, another method should be used by default.

    Parameters
    ==========
    filename : str or path-like
        The name of the file to load the weights from.
    format : str, optional
        The format of the file: :code:`"hdf5"` (also accepted as
        :code:`"h5"`), :code:`"npz"`, or the deprecated :code:`"json"`.
        If not given, it is inferred from the file extension.

    Returns
    =======
    weights: dict
        Dictionary containing the ROQ weights.

    Raises
    ======
    IOError
        If the format is not recognized.
    """
    if format is None:
        format = str(filename).split(".")[-1]
    if format == "h5":
        # ".h5" is a common extension for HDF5 files
        format = "hdf5"
    if format not in ["json", "npz", "hdf5"]:
        raise IOError(f"Format {format} not recognized.")
    if format == "json":
        import warnings

        warnings.warn(
            "json format for ROQ weights is deprecated, use hdf5 instead.",
            DeprecationWarning
        )
    logger.info(f"Loading ROQ weights from {filename}")
    if format == "json" or format == "npz":
        # Old file format assumed to contain only a single basis
        if format == "json":
            with open(filename, 'r') as file:
                weights = json.load(file, object_hook=decode_bilby_json)
        else:
            # Copy into a dict to load the data into memory, and close the
            # NpzFile handle deterministically instead of leaving it open.
            with np.load(filename) as npz_file:
                weights = dict(npz_file)
        for basis_type in ['linear', 'quadratic']:
            for ifo in self.interferometers:
                key = f'{ifo.name}_{basis_type}'
                # wrap as a one-element list to match the multi-basis layout
                weights[key] = [weights[key]]
    else:
        weights = dict()
        import h5py
        with h5py.File(filename, 'r') as f:
            weights['time_samples'] = f['time_samples'][()]
            for basis_type in ['linear', 'quadratic']:
                key = f'prior_range_{basis_type}'
                if key in f:
                    idxs_in_prior_range, selected_prior_ranges = \
                        self._select_prior_ranges(f[key])
                    weights[key] = selected_prior_ranges
                else:
                    idxs_in_prior_range = [0]
                key = f'frequency_nodes_{basis_type}'
                if key in f:
                    weights[key] = [f[key][str(i)][()]
                                    for i in idxs_in_prior_range]
                for ifo in self.interferometers:
                    key = f"{ifo.name}_{basis_type}"
                    weights[key] = [f[key][str(i)][()]
                                    for i in idxs_in_prior_range]
    return weights

1125 

def _get_time_resolution(self):
    """
    This method estimates the time resolution given the optimal SNR of the
    signal in the detector. This is then used when constructing the weights
    for the ROQ.

    A minimum resolution is set by assuming the SNR in each detector is at
    least 10. When the SNR is not available the SNR is assumed to be 30 in
    each detector.

    The characteristic frequency is evaluated with the PSD of every detector.
    The largest value is used, which gives the finest (most conservative) time
    step. The number of time samples is then rounded up to a power of two. It
    is never smaller than ``frequency_array[-1] * duration + 1``.

    Returns
    =======
    delta_t: float
        Time resolution
    """

    def calc_fhigh(freq, psd, scaling=20.):
        """

        Parameters
        ==========
        freq: array-like
            Frequency array
        psd: array-like
            Power spectral density
        scaling: float
            SNR dependent scaling factor

        Returns
        =======
        f_high: float
            The maximum frequency which must be considered
        """
        from scipy.integrate import simpson
        integrand1 = np.power(freq, -7. / 3) / psd
        integral1 = simpson(y=integrand1, x=freq)
        integrand3 = np.power(freq, 2. / 3.) / (psd * integral1)
        f_3_bar = simpson(y=integrand3, x=freq)

        f_high = scaling * f_3_bar**(1 / 3)

        return f_high

    def c_f_scaling(snr):
        return (np.pi**2 * snr**2 / 6)**(1 / 3)

    inj_snr_sq = 0
    for ifo in self.interferometers:
        inj_snr_sq += max(10, ifo.meta_data.get('optimal_SNR', 30))**2
    scaling = c_f_scaling(inj_snr_sq**0.5)

    # Evaluate every detector explicitly rather than relying on the loop
    # variable leaked from the loop above (which selected only the last one).
    # The largest f_high gives the shortest, i.e. safest, time step.
    fhigh = max(
        calc_fhigh(
            ifo.frequency_array[ifo.frequency_mask],
            ifo.power_spectral_density_array[ifo.frequency_mask],
            scaling=scaling)
        for ifo in self.interferometers)

    delta_t = fhigh**-1

    # Apply a safety factor to ensure the time step is short enough
    delta_t = delta_t / 5

    # duration / delta_t needs to be a power of 2 for IFFT
    number_of_time_samples = max(
        self.interferometers.duration / delta_t,
        self.interferometers.frequency_array[-1] * self.interferometers.duration + 1)
    number_of_time_samples = int(2**np.ceil(np.log2(number_of_time_samples)))
    delta_t = self.interferometers.duration / number_of_time_samples
    logger.info("ROQ time-step = {}".format(delta_t))
    return delta_t

1193 

1194 def _rescale_signal(self, signal, new_distance): 

1195 for kind in ['linear', 'quadratic']: 

1196 for mode in signal[kind]: 

1197 signal[kind][mode] *= self._ref_dist / new_distance 

1198 

def generate_time_sample_from_marginalized_likelihood(self, signal_polarizations=None):
    """
    Draw a time sample from the time-marginalized likelihood.

    The log-likelihood is evaluated at every time in ``self._times``, shifted
    by ``self.parameters["time_jitter"]`` when ``self.jitter_time`` is set.
    Distance or phase marginalization is applied if enabled. The result is
    weighted by the time prior, and one time is drawn with probability
    proportional to that weight.

    Parameters
    ==========
    signal_polarizations: dict, optional
        Frequency-domain waveform polarizations. If not given, they are
        generated with ``self.waveform_generator`` at ``self.parameters``.

    Returns
    =======
    A single element of the (possibly jittered) time array.
    """
    from ...core.utils.random import rng

    self.parameters.update(self.get_sky_frame_parameters())
    if signal_polarizations is None:
        signal_polarizations = \
            self.waveform_generator.frequency_domain_strain(self.parameters)

    snrs = self._CalculatedSNRs()

    for interferometer in self.interferometers:
        snrs += self.calculate_snrs(
            waveform_polarizations=signal_polarizations,
            interferometer=interferometer
        )
    d_inner_h = snrs.d_inner_h_array
    h_inner_h = snrs.optimal_snr_squared

    if self.distance_marginalization:
        time_log_like = self.distance_marginalized_likelihood(
            d_inner_h, h_inner_h)
    elif self.phase_marginalization:
        time_log_like = ln_i0(abs(d_inner_h)) - h_inner_h.real / 2
    else:
        time_log_like = (d_inner_h.real - h_inner_h.real / 2)

    times = self._times
    if self.jitter_time:
        times = times + self.parameters["time_jitter"]
    # NOTE(review): the prior key is hard-coded to 'geocent_time', whereas
    # _set_weights looks up f'{self.time_reference}_time'. Confirm this is
    # intended when time_reference is not 'geocent'.
    time_prior_array = self.priors['geocent_time'].prob(times)
    # Subtract the maximum before exponentiating to avoid overflow. np.max is
    # a vectorised reduction, unlike the builtin max, which iterates in Python.
    time_post = np.exp(time_log_like - np.max(time_log_like)) * time_prior_array
    time_post /= np.sum(time_post)
    return rng.choice(times, p=time_post)

1232 

1233 

class BilbyROQParamsRangeError(Exception):
    """
    Raised when the ROQ bases are incompatible with the prior range.

    For example, it is raised when no linear or quadratic ROQ basis lies
    within the prior range.
    """