Coverage for bilby/gw/result.py: 21%

362 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import json 

2import os 

3import pickle 

4 

5import numpy as np 

6 

7from ..core.result import Result as CoreResult 

8from ..core.utils import ( 

9 infft, logger, check_directory_exists_and_if_not_mkdir, 

10 latex_plot_format, safe_file_dump, safe_save_figure, 

11) 

12from .utils import plot_spline_pos, spline_angle_xform, asd_from_freq_series 

13from .detector import get_empty_interferometer, Interferometer 

14 

15 

class CompactBinaryCoalescenceResult(CoreResult):
    """
    Result class with additional methods and attributes specific to analyses
    of compact binaries.

    Most properties are read from the nested ``meta_data['likelihood']``
    dictionary and raise ``AttributeError`` when the entry is unavailable.
    """

    def __init__(self, **kwargs):
        super(CompactBinaryCoalescenceResult, self).__init__(**kwargs)

    def __get_from_nested_meta_data(self, *keys):
        """ Look up a value in ``self.meta_data`` by following ``keys``

        Parameters
        ==========
        *keys: hashable
            Successive keys, e.g. ``('likelihood', 'duration')`` returns
            ``self.meta_data['likelihood']['duration']``.

        Returns
        =======
        The stored value (None if no keys are given).

        Raises
        ======
        AttributeError
            If a key is missing, or if ``meta_data`` or an intermediate
            value cannot be indexed (e.g. it is None). AttributeError is
            used so the properties below behave like missing attributes.
        """
        if not keys:
            return None
        item = self.meta_data
        try:
            for key in keys:
                item = item[key]
        except (KeyError, TypeError) as err:
            # TypeError covers meta_data (or a nested level) being None or
            # otherwise not subscriptable by the given key.
            raise AttributeError(
                "No information stored for {}".format('/'.join(keys))) from err
        return item

    @property
    def sampling_frequency(self):
        """ Sampling frequency in Hertz"""
        return self.__get_from_nested_meta_data(
            'likelihood', 'sampling_frequency')

    @property
    def duration(self):
        """ Duration in seconds """
        return self.__get_from_nested_meta_data(
            'likelihood', 'duration')

    @property
    def start_time(self):
        """ Start time in seconds """
        return self.__get_from_nested_meta_data(
            'likelihood', 'start_time')

    @property
    def time_marginalization(self):
        """ Boolean for if the likelihood used time marginalization """
        return self.__get_from_nested_meta_data(
            'likelihood', 'time_marginalization')

    @property
    def phase_marginalization(self):
        """ Boolean for if the likelihood used phase marginalization """
        return self.__get_from_nested_meta_data(
            'likelihood', 'phase_marginalization')

    @property
    def distance_marginalization(self):
        """ Boolean for if the likelihood used distance marginalization """
        return self.__get_from_nested_meta_data(
            'likelihood', 'distance_marginalization')

    @property
    def interferometers(self):
        """ List of interferometer names (the keys stored under
        ``meta_data['likelihood']['interferometers']``) """
        return list(self.__get_from_nested_meta_data(
            'likelihood', 'interferometers'))

    @property
    def waveform_approximant(self):
        """ String of the waveform approximant """
        return self.__get_from_nested_meta_data(
            'likelihood', 'waveform_arguments', 'waveform_approximant')

    @property
    def waveform_generator_class(self):
        """ The waveform generator class (callable used to build the
        waveform generator in the plotting methods) """
        return self.__get_from_nested_meta_data(
            'likelihood', 'waveform_generator_class')

    @property
    def waveform_arguments(self):
        """ Dict of waveform arguments """
        return self.__get_from_nested_meta_data(
            'likelihood', 'waveform_arguments')

    @property
    def reference_frequency(self):
        """ Float of the reference frequency """
        return self.__get_from_nested_meta_data(
            'likelihood', 'waveform_arguments', 'reference_frequency')

    @property
    def frequency_domain_source_model(self):
        """ The frequency domain source model (function)"""
        return self.__get_from_nested_meta_data(
            'likelihood', 'frequency_domain_source_model')

    @property
    def time_domain_source_model(self):
        """ The time domain source model (function)"""
        return self.__get_from_nested_meta_data(
            'likelihood', 'time_domain_source_model')

    @property
    def parameter_conversion(self):
        """ The parameter conversion function passed to the waveform
        generator """
        return self.__get_from_nested_meta_data(
            'likelihood', 'parameter_conversion')

    def detector_injection_properties(self, detector):
        """ Returns a dictionary of the injection properties for each detector

        The injection properties include the parameters injected, and
        information about the signal to noise ratio (SNR) given the noise
        properties.

        Parameters
        ==========
        detector: str [H1, L1, V1]
            Detector name

        Returns
        =======
        injection_properties: dict
            A dictionary of the injection properties, or None if nothing is
            stored for this detector

        """
        try:
            return self.__get_from_nested_meta_data(
                'likelihood', 'interferometers', detector)
        except AttributeError:
            logger.info("No injection for detector {}".format(detector))
            return None

    @latex_plot_format
    def plot_calibration_posterior(self, level=.9, format="png"):
        """ Plots the calibration amplitude and phase uncertainty.
        Adapted from the LALInference version in bayespputils

        Plot is saved to {self.outdir}/{self.label}_calibration.{format}

        Calibration parameters are identified by containing ``recalib_``;
        the detector name is taken as the second ``_``-separated token of
        the parameter name. If none are found, nothing is plotted.

        Parameters
        ==========
        level: float
            Quantile for confidence levels, default=0.9, i.e., 90% interval
        format: str
            Format to save the plot, default=png, options are png/pdf

        Raises
        ======
        ValueError
            If ``format`` is not png or pdf.
        """
        import matplotlib.pyplot as plt
        if format not in ["png", "pdf"]:
            raise ValueError("Format should be one of png or pdf")

        posterior = self.posterior
        parameters = posterior.keys()
        ifos = np.unique([param.split('_')[1] for param in parameters if 'recalib_' in param])
        if ifos.size == 0:
            logger.info("No calibration parameters found. Aborting calibration plot.")
            return

        # Create the figure only once there is something to draw, so the early
        # return above does not leave an unclosed figure behind.
        fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(15, 15), dpi=500)
        font_size = 32
        outdir = self.outdir
        ifo_colors = {'H1': 'r', 'L1': 'g', 'V1': 'm'}
        # Round rather than truncate so e.g. level=0.29 is labelled 29%.
        percent = int(round(level * 100))

        for ifo in ifos:
            color = ifo_colors.get(ifo, 'c')
            label = r"{0} (mean, {1}$\%$)".format(ifo.upper(), percent)

            # Assume spline control frequencies are constant, so the first
            # posterior sample gives the node locations.
            freq_params = np.sort([param for param in parameters if
                                   'recalib_{0}_frequency_'.format(ifo) in param])
            logfreqs = np.log([posterior[param].iloc[0] for param in freq_params])

            # Amplitude calibration model (plotted in percent)
            plt.sca(ax1)
            amp_params = np.sort([param for param in parameters if
                                  'recalib_{0}_amplitude_'.format(ifo) in param])
            if len(amp_params) > 0:
                amplitude = 100 * np.column_stack([posterior[param] for param in amp_params])
                plot_spline_pos(logfreqs, amplitude, color=color, level=level,
                                label=label)

            # Phase calibration model
            plt.sca(ax2)
            phase_params = np.sort([param for param in parameters if
                                    'recalib_{0}_phase_'.format(ifo) in param])
            if len(phase_params) > 0:
                phase = np.column_stack([posterior[param] for param in phase_params])
                plot_spline_pos(logfreqs, phase, color=color, level=level,
                                label=label, xform=spline_angle_xform)

        ax1.tick_params(labelsize=.75 * font_size)
        ax2.tick_params(labelsize=.75 * font_size)
        # plt.legend acts on the current axes, i.e. the phase panel.
        plt.legend(loc='upper right', prop={'size': .75 * font_size}, framealpha=0.1)
        ax1.set_xscale('log')
        ax2.set_xscale('log')

        ax2.set_xlabel('Frequency [Hz]', fontsize=font_size)
        ax1.set_ylabel(r'Amplitude [$\%$]', fontsize=font_size)
        ax2.set_ylabel('Phase [deg]', fontsize=font_size)

        filename = os.path.join(outdir, self.label + '_calibration.' + format)
        fig.tight_layout()
        safe_save_figure(
            fig=fig, filename=filename,
            format=format, dpi=600, bbox_inches='tight'
        )
        logger.debug("Calibration figure saved to {}".format(filename))
        plt.close(fig)

    def plot_waveform_posterior(
            self, interferometers=None, level=0.9, n_samples=None,
            format='png', start_time=None, end_time=None):
        """
        Plot the posterior for the waveform in the frequency domain and
        whitened time domain for all detectors.

        If the strain data is passed that will be plotted.

        If injection parameters can be found, the injection will be plotted.

        Parameters
        ==========
        interferometers: (list, bilby.gw.detector.InterferometerList, optional)
            detectors (names or Interferometer objects) to plot; default is
            every interferometer stored in the result meta data
        level: float, optional
            symmetric confidence interval to show, default is 90%
        n_samples: int, optional
            number of samples to use to calculate the median/interval
            default is all
        format: str, optional
            format to save the figure in, default is png
        start_time: float, optional
            offset, relative to the merger time, at which to begin the time
            domain plot (negative values are before merger).
            the merger time is defined as the mean of the geocenter time
            posterior. Default is - 0.4
        end_time: float, optional
            offset, relative to the merger time, at which to end the time
            domain plot (positive values are after merger).
            the merger time is defined as the mean of the geocenter time
            posterior. Default is 0.2

        Raises
        ======
        TypeError
            If ``interferometers`` is given and is not a list.
        """
        if interferometers is None:
            interferometers = self.interferometers
        elif not isinstance(interferometers, list):
            raise TypeError(
                'interferometers must be a list or InterferometerList')
        for ifo in interferometers:
            self.plot_interferometer_waveform_posterior(
                interferometer=ifo, level=level, n_samples=n_samples,
                save=True, format=format, start_time=start_time,
                end_time=end_time)

    @latex_plot_format
    def plot_interferometer_waveform_posterior(
            self, interferometer, level=0.9, n_samples=None, save=True,
            format='png', start_time=None, end_time=None):
        """
        Plot the posterior for the waveform in the frequency domain and
        whitened time domain.

        If the strain data is passed that will be plotted.

        If injection parameters can be found, the injection will be plotted.

        Parameters
        ==========
        interferometer: (str, bilby.gw.detector.interferometer.Interferometer)
            detector to use, if an Interferometer object is passed the data
            will be overlaid on the posterior
        level: float, optional
            symmetric confidence interval to show, default is 90%
        n_samples: int, optional
            number of samples to use to calculate the central curve/interval
            (mean for image formats, median for html), default is all
        save: bool, optional
            whether to save the image, default=True
            if False, figure handle is returned
        format: str, optional
            format to save the figure in, default is png. "html" produces an
            interactive plotly figure (falls back to png when plotly cannot
            be imported)
        start_time: float, optional
            offset, relative to the merger time, at which to begin the time
            domain plot (negative values are before merger).
            the merger time is defined as the mean of the geocenter time
            posterior. Default is - 0.4
        end_time: float, optional
            offset, relative to the merger time, at which to end the time
            domain plot (positive values are after merger).
            the merger time is defined as the mean of the geocenter time
            posterior. Default is 0.2

        Returns
        =======
        fig: figure-handle, only if save=False
            a plotly figure for format="html", otherwise a matplotlib figure

        Raises
        ======
        TypeError
            If ``interferometer`` is neither a str nor an Interferometer.

        Notes
        -----
        To reduce the memory footprint we decimate the frequency domain
        waveforms to have ~4000 entries. This should be sufficient for decent
        resolution.

        The html output always covers the full data segment; ``start_time``
        and ``end_time`` only restrict the time window for image formats.
        The global matplotlib font size is set to 20 while plotting and is
        always restored afterwards.
        """

        DATA_COLOR = "#ff7f0e"
        WAVEFORM_COLOR = "#1f77b4"
        INJECTION_COLOR = "#000000"

        if format == "html":
            try:
                import plotly.graph_objects as go
                from plotly.offline import plot
                from plotly.subplots import make_subplots
            except ImportError:
                logger.warning(
                    "HTML plotting requested, but plotly cannot be imported, "
                    "falling back to png format for waveform plot.")
                format = "png"

        if isinstance(interferometer, str):
            # Only a name was given: build the detector with zero-noise strain
            # data, so there is no measured data to overlay.
            interferometer = get_empty_interferometer(interferometer)
            interferometer.set_strain_data_from_zero_noise(
                sampling_frequency=self.sampling_frequency,
                duration=self.duration, start_time=self.start_time)
            PLOT_DATA = False
        elif not isinstance(interferometer, Interferometer):
            raise TypeError(
                'interferometer must be either str or Interferometer')
        else:
            PLOT_DATA = True
        logger.info("Generating waveform figure for {}".format(
            interferometer.name))

        if n_samples is None:
            samples = self.posterior
        elif n_samples > len(self.posterior):
            logger.debug(
                "Requested more waveform samples ({}) than we have "
                "posterior samples ({})!".format(
                    n_samples, len(self.posterior)
                )
            )
            samples = self.posterior
        else:
            samples = self.posterior.sample(n_samples, replace=False)

        # The time-domain window is given as offsets from the mean merger time.
        merger_time = np.mean(samples.geocent_time)
        if start_time is None:
            start_time = - 0.4
        start_time = merger_time + start_time
        if end_time is None:
            end_time = 0.2
        end_time = merger_time + end_time
        if format == "html":
            # The interactive plot shows the full data segment.
            start_time = - np.inf
            end_time = np.inf
        time_idxs = (
            (interferometer.time_array >= start_time) &
            (interferometer.time_array <= end_time)
        )
        frequency_idxs = np.where(interferometer.frequency_mask)[0]
        logger.debug("Frequency mask contains {} values".format(
            len(frequency_idxs))
        )
        # Decimate to roughly 4000 frequency points (see Notes).
        frequency_idxs = frequency_idxs[::max(1, len(frequency_idxs) // 4000)]
        logger.debug("Downsampling frequency mask to {} values".format(
            len(frequency_idxs))
        )
        # Boolean indexing returns a copy, so shifting in place is safe.
        plot_times = interferometer.time_array[time_idxs]
        plot_times -= interferometer.strain_data.start_time
        start_time -= interferometer.strain_data.start_time
        end_time -= interferometer.strain_data.start_time
        plot_frequencies = interferometer.frequency_array[frequency_idxs]

        waveform_generator = self.waveform_generator_class(
            duration=self.duration, sampling_frequency=self.sampling_frequency,
            start_time=self.start_time,
            frequency_domain_source_model=self.frequency_domain_source_model,
            time_domain_source_model=self.time_domain_source_model,
            parameter_conversion=self.parameter_conversion,
            waveform_arguments=self.waveform_arguments)

        # Project each posterior sample onto the detector: keep the decimated
        # frequency-domain response and the whitened time series in the window.
        # This is done before any figure exists so a failure here leaves no
        # open figure or modified rcParams behind.
        fd_waveforms = list()
        td_waveforms = list()
        for _, params in samples.iterrows():
            params = dict(params)
            wf_pols = waveform_generator.frequency_domain_strain(params)
            fd_waveform = interferometer.get_detector_response(wf_pols, params)
            fd_waveforms.append(fd_waveform[frequency_idxs])
            whitened_fd_waveform = interferometer.whiten_frequency_series(fd_waveform)
            td_waveform = interferometer.get_whitened_time_series_from_whitened_frequency_series(
                whitened_fd_waveform
            )[time_idxs]
            td_waveforms.append(td_waveform)
        fd_waveforms = asd_from_freq_series(
            fd_waveforms,
            1 / interferometer.strain_data.duration)
        td_waveforms = np.array(td_waveforms)

        delta = (1 + level) / 2
        upper_percentile = delta * 100
        lower_percentile = (1 - delta) * 100
        logger.debug(
            'Plotting posterior between the {} and {} percentiles'.format(
                lower_percentile, upper_percentile
            )
        )
        fd_lower = np.percentile(fd_waveforms, lower_percentile, axis=0)
        fd_upper = np.percentile(fd_waveforms, upper_percentile, axis=0)
        td_lower = np.percentile(td_waveforms, lower_percentile, axis=0)
        td_upper = np.percentile(td_waveforms, upper_percentile, axis=0)

        old_font_size = None
        if format != "html":
            import matplotlib.pyplot as plt
            from matplotlib import rcParams
            old_font_size = rcParams["font.size"]
            rcParams["font.size"] = 20

        try:
            if format == "html":
                fig = make_subplots(
                    rows=2, cols=1,
                    row_heights=[0.5, 0.5],
                )
                fig.update_layout(
                    template='plotly_white',
                    font=dict(
                        family="Computer Modern",
                    )
                )
            else:
                fig, axs = plt.subplots(
                    2, 1,
                    gridspec_kw=dict(height_ratios=[1.5, 1]),
                    figsize=(16, 12.5)
                )

            if PLOT_DATA:
                data_asd = asd_from_freq_series(
                    interferometer.frequency_domain_strain[frequency_idxs],
                    1 / interferometer.strain_data.duration
                )
                if format == "html":
                    fig.add_trace(
                        go.Scatter(
                            x=plot_frequencies,
                            y=data_asd,
                            fill=None,
                            mode='lines', line_color=DATA_COLOR,
                            opacity=0.5,
                            name="Data",
                            legendgroup='data',
                        ),
                        row=1,
                        col=1,
                    )
                    fig.add_trace(
                        go.Scatter(
                            x=plot_frequencies,
                            y=interferometer.amplitude_spectral_density_array[frequency_idxs],
                            fill=None,
                            mode='lines', line_color=DATA_COLOR,
                            opacity=0.8,
                            name="ASD",
                            legendgroup='asd',
                        ),
                        row=1,
                        col=1,
                    )
                    fig.add_trace(
                        go.Scatter(
                            x=plot_times,
                            y=interferometer.whitened_time_domain_strain[time_idxs],
                            fill=None,
                            mode='lines', line_color=DATA_COLOR,
                            opacity=0.5,
                            name="Data",
                            legendgroup='data',
                            showlegend=False,
                        ),
                        row=2,
                        col=1,
                    )
                else:
                    axs[0].loglog(
                        plot_frequencies, data_asd,
                        color=DATA_COLOR, label='Data', alpha=0.3)
                    axs[0].loglog(
                        plot_frequencies,
                        interferometer.amplitude_spectral_density_array[frequency_idxs],
                        color=DATA_COLOR, label='ASD')
                    axs[1].plot(
                        plot_times, interferometer.whitened_time_domain_strain[time_idxs],
                        color=DATA_COLOR, alpha=0.3)
                logger.debug('Plotted interferometer data.')

            if format == "html":
                interval_name = "{:.2f}% credible interval".format(
                    upper_percentile - lower_percentile)
                fig.add_trace(
                    go.Scatter(
                        x=plot_frequencies, y=np.median(fd_waveforms, axis=0),
                        fill=None,
                        mode='lines', line_color=WAVEFORM_COLOR,
                        opacity=1,
                        name="Median reconstructed",
                        legendgroup='median',
                    ),
                    row=1,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_frequencies, y=fd_lower,
                        fill=None,
                        mode='lines',
                        line_color=WAVEFORM_COLOR,
                        opacity=0.1,
                        name=interval_name,
                        legendgroup='uncertainty',
                    ),
                    row=1,
                    col=1,
                )
                # fill='tonexty' shades down to the previously added (lower) trace.
                fig.add_trace(
                    go.Scatter(
                        x=plot_frequencies, y=fd_upper,
                        fill='tonexty',
                        mode='lines',
                        line_color=WAVEFORM_COLOR,
                        opacity=0.1,
                        name=interval_name,
                        legendgroup='uncertainty',
                        showlegend=False,
                    ),
                    row=1,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_times, y=np.median(td_waveforms, axis=0),
                        fill=None,
                        mode='lines', line_color=WAVEFORM_COLOR,
                        opacity=1,
                        name="Median reconstructed",
                        legendgroup='median',
                        showlegend=False,
                    ),
                    row=2,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_times, y=td_lower,
                        fill=None,
                        mode='lines',
                        line_color=WAVEFORM_COLOR,
                        opacity=0.1,
                        name=interval_name,
                        legendgroup='uncertainty',
                        showlegend=False,
                    ),
                    row=2,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_times, y=td_upper,
                        fill='tonexty',
                        mode='lines',
                        line_color=WAVEFORM_COLOR,
                        opacity=0.1,
                        name=interval_name,
                        legendgroup='uncertainty',
                        showlegend=False,
                    ),
                    row=2,
                    col=1,
                )
            else:
                fd_mean = np.mean(fd_waveforms, axis=0)
                # Lower y-limit of the ASD panel: 1/1000 of the mean
                # reconstructed ASD at the lowest plotted frequency.
                lower_limit = fd_mean[0] / 1e3
                axs[0].loglog(
                    plot_frequencies,
                    fd_mean, color=WAVEFORM_COLOR, label='Mean reconstructed')
                axs[0].fill_between(
                    plot_frequencies,
                    fd_lower,
                    fd_upper,
                    color=WAVEFORM_COLOR,
                    # Round rather than truncate to avoid e.g. "67%" for 68%.
                    label=r'{}% credible interval'.format(
                        int(round(upper_percentile - lower_percentile))),
                    alpha=0.3)
                axs[1].plot(
                    plot_times, np.mean(td_waveforms, axis=0),
                    color=WAVEFORM_COLOR)
                axs[1].fill_between(
                    plot_times, td_lower, td_upper,
                    color=WAVEFORM_COLOR,
                    alpha=0.3)

            if self.injection_parameters is not None:
                try:
                    hf_inj = waveform_generator.frequency_domain_strain(
                        self.injection_parameters)
                    hf_inj_det = interferometer.get_detector_response(
                        hf_inj, self.injection_parameters)
                    # Whiten by the ASD and inverse-FFT to the time domain.
                    ht_inj_det = infft(
                        hf_inj_det * np.sqrt(2. / interferometer.sampling_frequency) /
                        interferometer.amplitude_spectral_density_array,
                        self.sampling_frequency)[time_idxs]
                    injection_asd = asd_from_freq_series(
                        hf_inj_det[frequency_idxs],
                        1 / interferometer.strain_data.duration)
                    if format == "html":
                        fig.add_trace(
                            go.Scatter(
                                x=plot_frequencies,
                                y=injection_asd,
                                fill=None,
                                mode='lines',
                                line=dict(color=INJECTION_COLOR, dash='dot'),
                                name="Injection",
                                legendgroup='injection',
                            ),
                            row=1,
                            col=1,
                        )
                        fig.add_trace(
                            go.Scatter(
                                x=plot_times, y=ht_inj_det,
                                fill=None,
                                mode='lines',
                                line=dict(color=INJECTION_COLOR, dash='dot'),
                                name="Injection",
                                legendgroup='injection',
                                showlegend=False,
                            ),
                            row=2,
                            col=1,
                        )
                    else:
                        axs[0].loglog(
                            plot_frequencies, injection_asd,
                            color=INJECTION_COLOR, label='Injection', linestyle=':')
                        axs[1].plot(
                            plot_times, ht_inj_det,
                            color=INJECTION_COLOR, linestyle=':')
                    logger.debug('Plotted injection.')
                except IndexError as e:
                    logger.info('Failed to plot injection with message {}.'.format(e))

            f_domain_x_label = "$f [\\mathrm{Hz}]$"
            f_domain_y_label = "$\\mathrm{ASD} \\left[\\mathrm{Hz}^{-1/2}\\right]$"
            t_domain_x_label = "$t - {} [s]$".format(interferometer.strain_data.start_time)
            t_domain_y_label = "Whitened Strain"
            if format == "html":
                fig.update_xaxes(title_text=f_domain_x_label, type="log", row=1)
                fig.update_yaxes(title_text=f_domain_y_label, type="log", row=1)
                fig.update_xaxes(title_text=t_domain_x_label, type="linear", row=2)
                fig.update_yaxes(title_text=t_domain_y_label, type="linear", row=2)
            else:
                axs[0].set_xlim(interferometer.minimum_frequency,
                                interferometer.maximum_frequency)
                axs[1].set_xlim(start_time, end_time)
                axs[0].set_ylim(lower_limit)
                axs[0].set_xlabel(f_domain_x_label)
                axs[0].set_ylabel(f_domain_y_label)
                axs[1].set_xlabel(t_domain_x_label)
                axs[1].set_ylabel(t_domain_y_label)
                axs[0].legend(loc='lower left', ncol=2)

            if not save:
                return fig

            filename = os.path.join(
                self.outdir,
                self.label + '_{}_waveform.{}'.format(
                    interferometer.name, format))
            if format == 'html':
                plot(fig, filename=filename, include_mathjax='cdn', auto_open=False)
            else:
                fig.tight_layout()
                safe_save_figure(
                    fig=fig, filename=filename,
                    format=format, dpi=600
                )
                plt.close(fig)
            logger.debug("Waveform figure saved to {}".format(filename))
        finally:
            # Always restore the global font size. Previously this was skipped
            # on errors, and html + save=False raised NameError here because
            # matplotlib was never imported on that path.
            if old_font_size is not None:
                rcParams["font.size"] = old_font_size

    def plot_skymap(
            self, maxpts=None, trials=5, jobs=1, enable_multiresolution=True,
            objid=None, instruments=None, geo=False, dpi=600,
            transparent=False, colorbar=False, contour=(50, 90),
            annotate=True, cmap='cylon', load_pickle=False):
        """ Generate a fits file and sky map from a result

        Code adapted from ligo.skymap.tool.ligo_skymap_from_samples and
        ligo.skymap.tool.plot_skymap. Note, the use of this additionally
        requires the installation of ligo.skymap.

        Writes {outdir}/{label}_skypost.obj (when not loading a pickle),
        {outdir}/{label}_skymap.fits and {outdir}/{label}_skymap.png. If
        astropy, ligo.skymap or healpy cannot be imported, logs and returns
        without plotting.

        Parameters
        ==========
        maxpts: int
            Maximum number of samples to use, if None all samples are used
        trials: int
            Number of trials at each clustering number
        jobs: int
            Number of multiple threads
        enable_multiresolution: bool
            Generate a multiresolution HEALPix map (default: True)
        objid: str
            Event ID to store in FITS header
        instruments: str
            Name of detectors
        geo: bool
            Plot in geographic coordinates (lat, lon) instead of RA, Dec
        dpi: int
            Resolution of figure in dots per inch
        transparent: bool
            Save image with transparent background
        colorbar: bool
            Show colorbar
        contour: sequence of float, optional
            Contour (credible) levels in percent to draw, default (50, 90);
            None disables contours
        annotate: bool
            Annotate image with details
        cmap: str
            Name of the colormap to use
        load_pickle: bool, str
            If true, load the cached pickle file (default name), or the
            pickle-file given as a path.
        """
        import matplotlib.pyplot as plt
        from matplotlib import rcParams

        try:
            from astropy.time import Time
            from ligo.skymap import io, version, plot, postprocess, bayestar, kde
            import healpy as hp
        except ImportError as e:
            logger.info("Unable to generate skymap: error {}".format(e))
            return

        check_directory_exists_and_if_not_mkdir(self.outdir)

        logger.info('Reading samples for skymap')
        data = self.posterior

        if maxpts is not None and maxpts < len(data):
            logger.info('Taking random subsample of chain')
            data = data.sample(maxpts)

        default_obj_filename = os.path.join(self.outdir, '{}_skypost.obj'.format(self.label))

        if load_pickle is False:
            try:
                pts = data[['ra', 'dec', 'luminosity_distance']].values
                kde_class = kde.Clustered2Plus1DSkyKDE
                distance = True
            except KeyError:
                logger.warning("The results file does not contain luminosity_distance")
                pts = data[['ra', 'dec']].values
                kde_class = kde.Clustered2DSkyKDE
                distance = False

            logger.info('Initialising skymap class')
            skypost = kde_class(pts, trials=trials, jobs=jobs)
            logger.info('Pickling skymap to {}'.format(default_obj_filename))
            safe_file_dump(skypost, default_obj_filename, "pickle")

        else:
            if isinstance(load_pickle, str):
                obj_filename = load_pickle
            else:
                obj_filename = default_obj_filename
            logger.info('Reading from pickle {}'.format(obj_filename))
            # NOTE(security): unpickling can execute arbitrary code; only load
            # files from a trusted source (e.g. a previous call to this method).
            with open(obj_filename, 'rb') as handle:
                skypost = pickle.load(handle)
            skypost.jobs = jobs
            distance = isinstance(skypost, kde.Clustered2Plus1DSkyKDE)

        logger.info('Making skymap')
        hpmap = skypost.as_healpix()
        if not enable_multiresolution:
            hpmap = bayestar.rasterize(hpmap)

        hpmap.meta.update(io.fits.metadata_for_version_module(version))
        hpmap.meta['creator'] = "bilby"
        hpmap.meta['origin'] = 'LIGO/Virgo'
        hpmap.meta['gps_creation_time'] = Time.now().gps
        hpmap.meta['history'] = ""
        if objid is not None:
            hpmap.meta['objid'] = objid
        if instruments:
            hpmap.meta['instruments'] = instruments
        # A pickled 3D KDE may be paired with a posterior that lacks
        # luminosity_distance; skip the distance summary instead of crashing.
        if distance and 'luminosity_distance' in data:
            hpmap.meta['distmean'] = np.mean(data['luminosity_distance'])
            hpmap.meta['diststd'] = np.std(data['luminosity_distance'])

        try:
            event_times = data['geocent_time']
            hpmap.meta['gps_time'] = event_times.mean()
        except KeyError:
            logger.warning('Cannot determine the event time from geocent_time')

        fits_filename = os.path.join(self.outdir, "{}_skymap.fits".format(self.label))
        logger.info('Saving skymap fits-file to {}'.format(fits_filename))
        io.write_sky_map(fits_filename, hpmap, nest=True)

        skymap, metadata = io.fits.read_sky_map(fits_filename, nest=None)
        nside = hp.npix2nside(len(skymap))

        # Convert sky map from probability to probability per square degree.
        deg2perpix = hp.nside2pixarea(nside, degrees=True)
        probperdeg2 = skymap / deg2perpix

        # Draw on a fresh figure (rather than whatever figure is current) and
        # close it after saving so repeated calls do not accumulate figures.
        fig = plt.figure()
        if geo:
            obstime = Time(metadata['gps_time'], format='gps').utc.isot
            ax = plt.axes(projection='geo degrees mollweide', obstime=obstime)
        else:
            ax = plt.axes(projection='astro hours mollweide')
        ax.grid()

        # Plot sky map.
        vmax = probperdeg2.max()
        img = ax.imshow_hpx(
            (probperdeg2, 'ICRS'), nested=metadata['nest'], vmin=0., vmax=vmax,
            cmap=cmap)

        # Add colorbar.
        if colorbar:
            cb = plot.colorbar(img)
            cb.set_label(r'prob. per deg$^2$')

        if contour is not None:
            credible_levels = 100 * postprocess.find_greedy_credible_levels(skymap)
            contours = ax.contour_hpx(
                (credible_levels, 'ICRS'), nested=metadata['nest'],
                colors='k', linewidths=0.5, levels=contour)
            fmt = r'%g\%%' if rcParams['text.usetex'] else '%g%%'
            plt.clabel(contours, fmt=fmt, fontsize=6, inline=True)

        # Add continents.
        if geo:
            geojson_filename = os.path.join(
                os.path.dirname(plot.__file__), 'ne_simplified_coastline.json')
            with open(geojson_filename, 'r') as geojson_file:
                geoms = json.load(geojson_file)['geometries']
            verts = [coord for geom in geoms
                     for coord in zip(*geom['coordinates'])]
            plt.plot(*verts, color='0.5', linewidth=0.5,
                     transform=ax.get_transform('world'))

        # Add a white outline to all text to make it stand out from the background.
        plot.outline_text(ax)

        if annotate:
            text = []
            try:
                objid = metadata['objid']
            except KeyError:
                pass
            else:
                text.append('event ID: {}'.format(objid))
            if contour:
                # Area enclosed by each requested credible level: number of
                # pixels below the level times the pixel area.
                pp = np.round(contour).astype(int)
                ii = np.round(np.searchsorted(np.sort(credible_levels), contour) *
                              deg2perpix).astype(int)
                for i, p in zip(ii, pp):
                    text.append(
                        '{:d}% area: {:d} deg$^2$'.format(p, i))
            ax.text(1, 1, '\n'.join(text), transform=ax.transAxes, ha='right')

        filename = os.path.join(self.outdir, "{}_skymap.png".format(self.label))
        logger.info("Generating 2D projected skymap to {}".format(filename))
        safe_save_figure(fig=fig, filename=filename, dpi=dpi, transparent=transparent)
        plt.close(fig)

886 

887 

# Short public alias for CompactBinaryCoalescenceResult.
CBCResult = CompactBinaryCoalescenceResult