Coverage for bilby/gw/result.py: 21%
362 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import json
2import os
3import pickle
5import numpy as np
7from ..core.result import Result as CoreResult
8from ..core.utils import (
9 infft, logger, check_directory_exists_and_if_not_mkdir,
10 latex_plot_format, safe_file_dump, safe_save_figure,
11)
12from .utils import plot_spline_pos, spline_angle_xform, asd_from_freq_series
13from .detector import get_empty_interferometer, Interferometer
class CompactBinaryCoalescenceResult(CoreResult):
    """
    Result class with additional methods and attributes specific to analyses
    of compact binaries.

    Most properties below are read-only views onto entries of
    ``self.meta_data['likelihood']``; they raise ``AttributeError`` when the
    requested entry was not stored.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __get_from_nested_meta_data(self, *keys):
        """ Walk ``self.meta_data`` along ``keys`` and return the final item.

        Parameters
        ==========
        *keys: str
            Successive keys, e.g. ``('likelihood', 'duration')`` looks up
            ``self.meta_data['likelihood']['duration']``.

        Returns
        =======
        item:
            The nested value, or ``None`` if no keys are given.

        Raises
        ======
        AttributeError
            If any key is missing, or if ``meta_data`` (or an intermediate
            value) cannot be indexed by key, e.g. because it is ``None``.
        """
        item = None
        dictionary = self.meta_data
        try:
            for key in keys:
                item = dictionary[key]
                dictionary = item
        except (KeyError, TypeError) as error:
            # Report every kind of failed lookup uniformly as a missing
            # attribute, which is what the properties below promise.
            raise AttributeError(
                "No information stored for {}".format('/'.join(keys))) from error
        return item

    @property
    def sampling_frequency(self):
        """ Sampling frequency in Hertz"""
        return self.__get_from_nested_meta_data(
            'likelihood', 'sampling_frequency')

    @property
    def duration(self):
        """ Duration in seconds """
        return self.__get_from_nested_meta_data(
            'likelihood', 'duration')

    @property
    def start_time(self):
        """ Start time in seconds """
        return self.__get_from_nested_meta_data(
            'likelihood', 'start_time')

    @property
    def time_marginalization(self):
        """ Boolean for if the likelihood used time marginalization """
        return self.__get_from_nested_meta_data(
            'likelihood', 'time_marginalization')

    @property
    def phase_marginalization(self):
        """ Boolean for if the likelihood used phase marginalization """
        return self.__get_from_nested_meta_data(
            'likelihood', 'phase_marginalization')

    @property
    def distance_marginalization(self):
        """ Boolean for if the likelihood used distance marginalization """
        return self.__get_from_nested_meta_data(
            'likelihood', 'distance_marginalization')

    @property
    def interferometers(self):
        """ List of interferometer names (the keys of the stored
        ``likelihood/interferometers`` entry) """
        return list(self.__get_from_nested_meta_data(
            'likelihood', 'interferometers'))

    @property
    def waveform_approximant(self):
        """ String of the waveform approximant """
        return self.__get_from_nested_meta_data(
            'likelihood', 'waveform_arguments', 'waveform_approximant')

    @property
    def waveform_generator_class(self):
        """ The waveform generator class used by the likelihood """
        return self.__get_from_nested_meta_data(
            'likelihood', 'waveform_generator_class')

    @property
    def waveform_arguments(self):
        """ Dict of waveform arguments """
        return self.__get_from_nested_meta_data(
            'likelihood', 'waveform_arguments')

    @property
    def reference_frequency(self):
        """ Float of the reference frequency """
        return self.__get_from_nested_meta_data(
            'likelihood', 'waveform_arguments', 'reference_frequency')

    @property
    def frequency_domain_source_model(self):
        """ The frequency domain source model (function)"""
        return self.__get_from_nested_meta_data(
            'likelihood', 'frequency_domain_source_model')

    @property
    def time_domain_source_model(self):
        """ The time domain source model (function)"""
        return self.__get_from_nested_meta_data(
            'likelihood', 'time_domain_source_model')

    @property
    def parameter_conversion(self):
        """ The parameter conversion function used by the waveform generator """
        return self.__get_from_nested_meta_data(
            'likelihood', 'parameter_conversion')

    def detector_injection_properties(self, detector):
        """ Returns a dictionary of the injection properties for each detector

        The injection properties include the parameters injected, and
        information about the signal to noise ratio (SNR) given the noise
        properties.

        Parameters
        ==========
        detector: str [H1, L1, V1]
            Detector name

        Returns
        =======
        injection_properties: dict
            A dictionary of the injection properties, or None (with an info
            log message) if nothing is stored for ``detector``.
        """
        try:
            return self.__get_from_nested_meta_data(
                'likelihood', 'interferometers', detector)
        except AttributeError:
            logger.info("No injection for detector {}".format(detector))
            return None

    @latex_plot_format
    def plot_calibration_posterior(self, level=.9, format="png"):
        """ Plots the calibration amplitude and phase uncertainty.
        Adapted from the LALInference version in bayespputils

        Plot is saved to {self.outdir}/{self.label}_calibration.{format}

        Parameters
        ==========
        level: float
            Quantile for confidence levels, default=0.9, i.e., 90% interval
        format: str
            Format to save the plot, default=png, options are png/pdf

        Raises
        ======
        ValueError
            If ``format`` is not "png" or "pdf".
        """
        import matplotlib.pyplot as plt
        if format not in ["png", "pdf"]:
            raise ValueError("Format should be one of png or pdf")

        posterior = self.posterior
        parameters = posterior.keys()
        # Detector name is taken as the second "_"-separated token of every
        # parameter containing "recalib_" (e.g. recalib_H1_amplitude_0 -> H1).
        ifos = np.unique([param.split('_')[1] for param in parameters if 'recalib_' in param])
        if ifos.size == 0:
            logger.info("No calibration parameters found. Aborting calibration plot.")
            return

        # Create the figure only after the early return above, so a result
        # without calibration parameters does not leave an open figure.
        fig, [ax1, ax2] = plt.subplots(2, 1, figsize=(15, 15), dpi=500)
        font_size = 32
        outdir = self.outdir
        ifo_colors = {'H1': 'r', 'L1': 'g', 'V1': 'm'}

        for ifo in ifos:
            color = ifo_colors.get(ifo, 'c')
            label = r"{0} (mean, {1}$\%$)".format(ifo.upper(), int(level * 100))

            # Assume spline control frequencies are constant: only the first
            # posterior sample of each frequency parameter is used.
            freq_params = np.sort([param for param in parameters if
                                   'recalib_{0}_frequency_'.format(ifo) in param])
            logfreqs = np.log([posterior[param].iloc[0] for param in freq_params])

            # Amplitude calibration model, shown in percent
            plt.sca(ax1)
            amp_params = np.sort([param for param in parameters if
                                  'recalib_{0}_amplitude_'.format(ifo) in param])
            if len(amp_params) > 0:
                amplitude = 100 * np.column_stack([posterior[param] for param in amp_params])
                plot_spline_pos(logfreqs, amplitude, color=color, level=level,
                                label=label)

            # Phase calibration model
            plt.sca(ax2)
            phase_params = np.sort([param for param in parameters if
                                    'recalib_{0}_phase_'.format(ifo) in param])
            if len(phase_params) > 0:
                phase = np.column_stack([posterior[param] for param in phase_params])
                plot_spline_pos(logfreqs, phase, color=color, level=level,
                                label=label, xform=spline_angle_xform)

        ax1.tick_params(labelsize=.75 * font_size)
        ax2.tick_params(labelsize=.75 * font_size)
        # The legend is drawn on the phase panel only.
        ax2.legend(loc='upper right', prop={'size': .75 * font_size}, framealpha=0.1)
        ax1.set_xscale('log')
        ax2.set_xscale('log')

        ax2.set_xlabel('Frequency [Hz]', fontsize=font_size)
        ax1.set_ylabel(r'Amplitude [$\%$]', fontsize=font_size)
        ax2.set_ylabel('Phase [deg]', fontsize=font_size)

        filename = os.path.join(outdir, self.label + '_calibration.' + format)
        fig.tight_layout()
        safe_save_figure(
            fig=fig, filename=filename,
            format=format, dpi=600, bbox_inches='tight'
        )
        logger.debug("Calibration figure saved to {}".format(filename))
        plt.close(fig)

    def plot_waveform_posterior(
            self, interferometers=None, level=0.9, n_samples=None,
            format='png', start_time=None, end_time=None):
        """
        Plot the posterior for the waveform in the frequency domain and
        whitened time domain for all detectors.

        If the strain data is passed that will be plotted.

        If injection parameters can be found, the injection will be plotted.

        One figure per detector is saved via
        ``plot_interferometer_waveform_posterior``.

        Parameters
        ==========
        interferometers: (list, bilby.gw.detector.InterferometerList, optional)
            detectors (names or Interferometer objects) to plot; defaults
            to the interferometer names stored in the meta data
        level: float, optional
            symmetric confidence interval to show, default is 90%
        n_samples: int, optional
            number of samples to use to calculate the median/interval
            default is all
        format: str, optional
            format to save the figure in, default is png
        start_time: float, optional
            offset in seconds, relative to the merger time, at which to begin
            the time domain plot. The merger time is defined as the mean of
            the geocenter time posterior. Default is -0.4
        end_time: float, optional
            offset in seconds, relative to the merger time, at which to end
            the time domain plot. The merger time is defined as the mean of
            the geocenter time posterior. Default is 0.2

        Raises
        ======
        TypeError
            If ``interferometers`` is given and is not a list.
        """
        if interferometers is None:
            interferometers = self.interferometers
        elif not isinstance(interferometers, list):
            raise TypeError(
                'interferometers must be a list or InterferometerList')
        for ifo in interferometers:
            self.plot_interferometer_waveform_posterior(
                interferometer=ifo, level=level, n_samples=n_samples,
                save=True, format=format, start_time=start_time,
                end_time=end_time)

    @latex_plot_format
    def plot_interferometer_waveform_posterior(
            self, interferometer, level=0.9, n_samples=None, save=True,
            format='png', start_time=None, end_time=None):
        """
        Plot the posterior for the waveform in the frequency domain and
        whitened time domain.

        If the strain data is passed that will be plotted.

        If injection parameters can be found, the injection will be plotted.

        Parameters
        ==========
        interferometer: (str, bilby.gw.detector.interferometer.Interferometer)
            detector to use, if an Interferometer object is passed the data
            will be overlaid on the posterior
        level: float, optional
            symmetric confidence interval to show, default is 90%
        n_samples: int, optional
            number of samples to use to calculate the median/interval
            default is all (also used if more samples are requested than
            the posterior contains)
        save: bool, optional
            whether to save the image, default=True
            if False, figure handle is returned
        format: str, optional
            format to save the figure in, default is png. "html" produces an
            interactive plotly figure; if plotly cannot be imported this
            falls back to png.
        start_time: float, optional
            offset in seconds, relative to the merger time, at which to begin
            the time domain plot. The merger time is defined as the mean of
            the geocenter time posterior. Default is -0.4
        end_time: float, optional
            offset in seconds, relative to the merger time, at which to end
            the time domain plot. The merger time is defined as the mean of
            the geocenter time posterior. Default is 0.2

        Returns
        =======
        fig: figure-handle, only if save=False
            A matplotlib figure, or a plotly figure for html output.

        Raises
        ======
        TypeError
            If ``interferometer`` is neither a str nor an Interferometer.

        Notes
        -----
        To reduce the memory footprint we decimate the frequency domain
        waveforms to have ~4000 entries. This should be sufficient for decent
        resolution.

        The central reconstructed curve is the posterior mean for matplotlib
        output and the posterior median for html output. Html output shows
        the full time segment, ignoring start_time and end_time.

        For matplotlib output the global font size is set to 20 while the
        figure is built and is restored afterwards, even if plotting fails.
        """

        DATA_COLOR = "#ff7f0e"
        WAVEFORM_COLOR = "#1f77b4"
        INJECTION_COLOR = "#000000"

        if format == "html":
            try:
                import plotly.graph_objects as go
                from plotly.offline import plot
                from plotly.subplots import make_subplots
            except ImportError:
                logger.warning(
                    "HTML plotting requested, but plotly cannot be imported, "
                    "falling back to png format for waveform plot.")
                format = "png"
        is_html = format == "html"

        if isinstance(interferometer, str):
            interferometer = get_empty_interferometer(interferometer)
            interferometer.set_strain_data_from_zero_noise(
                sampling_frequency=self.sampling_frequency,
                duration=self.duration, start_time=self.start_time)
            plot_data = False
        elif not isinstance(interferometer, Interferometer):
            raise TypeError(
                'interferometer must be either str or Interferometer')
        else:
            plot_data = True
        logger.info("Generating waveform figure for {}".format(
            interferometer.name))

        if n_samples is None:
            samples = self.posterior
        elif n_samples > len(self.posterior):
            logger.debug(
                "Requested more waveform samples ({}) than we have "
                "posterior samples ({})!".format(
                    n_samples, len(self.posterior)
                )
            )
            samples = self.posterior
        else:
            samples = self.posterior.sample(n_samples, replace=False)

        merger_time = np.mean(samples.geocent_time)
        if start_time is None:
            start_time = -0.4
        if end_time is None:
            end_time = 0.2
        start_time = merger_time + start_time
        end_time = merger_time + end_time
        if is_html:
            # Interactive output shows the whole data segment.
            start_time = -np.inf
            end_time = np.inf
        time_idxs = (
            (interferometer.time_array >= start_time) &
            (interferometer.time_array <= end_time)
        )
        frequency_idxs = np.where(interferometer.frequency_mask)[0]
        logger.debug("Frequency mask contains {} values".format(
            len(frequency_idxs))
        )
        frequency_idxs = frequency_idxs[::max(1, len(frequency_idxs) // 4000)]
        logger.debug("Downsampling frequency mask to {} values".format(
            len(frequency_idxs))
        )
        # Times are plotted relative to the start of the data segment.
        segment_start = interferometer.strain_data.start_time
        plot_times = interferometer.time_array[time_idxs] - segment_start
        start_time -= segment_start
        end_time -= segment_start
        plot_frequencies = interferometer.frequency_array[frequency_idxs]
        delta_f = 1 / interferometer.strain_data.duration

        waveform_generator = self.waveform_generator_class(
            duration=self.duration, sampling_frequency=self.sampling_frequency,
            start_time=self.start_time,
            frequency_domain_source_model=self.frequency_domain_source_model,
            time_domain_source_model=self.time_domain_source_model,
            parameter_conversion=self.parameter_conversion,
            waveform_arguments=self.waveform_arguments)

        if not is_html:
            import matplotlib.pyplot as plt
            from matplotlib import rcParams
            # Only the matplotlib path touches the global font size; remember
            # it so the ``finally`` clause below can restore it.
            old_font_size = rcParams["font.size"]

        try:
            if is_html:
                fig = make_subplots(
                    rows=2, cols=1,
                    row_heights=[0.5, 0.5],
                )
                fig.update_layout(
                    template='plotly_white',
                    font=dict(
                        family="Computer Modern",
                    )
                )
            else:
                rcParams["font.size"] = 20
                fig, axs = plt.subplots(
                    2, 1,
                    gridspec_kw=dict(height_ratios=[1.5, 1]),
                    figsize=(16, 12.5)
                )

            if plot_data:
                data_asd = asd_from_freq_series(
                    interferometer.frequency_domain_strain[frequency_idxs],
                    delta_f)
                if is_html:
                    fig.add_trace(
                        go.Scatter(
                            x=plot_frequencies,
                            y=data_asd,
                            fill=None,
                            mode='lines', line_color=DATA_COLOR,
                            opacity=0.5,
                            name="Data",
                            legendgroup='data',
                        ),
                        row=1,
                        col=1,
                    )
                    fig.add_trace(
                        go.Scatter(
                            x=plot_frequencies,
                            y=interferometer.amplitude_spectral_density_array[frequency_idxs],
                            fill=None,
                            mode='lines', line_color=DATA_COLOR,
                            opacity=0.8,
                            name="ASD",
                            legendgroup='asd',
                        ),
                        row=1,
                        col=1,
                    )
                    fig.add_trace(
                        go.Scatter(
                            x=plot_times,
                            y=interferometer.whitened_time_domain_strain[time_idxs],
                            fill=None,
                            mode='lines', line_color=DATA_COLOR,
                            opacity=0.5,
                            name="Data",
                            legendgroup='data',
                            showlegend=False,
                        ),
                        row=2,
                        col=1,
                    )
                else:
                    axs[0].loglog(
                        plot_frequencies, data_asd,
                        color=DATA_COLOR, label='Data', alpha=0.3)
                    axs[0].loglog(
                        plot_frequencies,
                        interferometer.amplitude_spectral_density_array[frequency_idxs],
                        color=DATA_COLOR, label='ASD')
                    axs[1].plot(
                        plot_times, interferometer.whitened_time_domain_strain[time_idxs],
                        color=DATA_COLOR, alpha=0.3)
                logger.debug('Plotted interferometer data.')

            # Project every selected posterior sample onto the detector, keep
            # the decimated frequency-domain response and the whitened
            # time-domain response within the plotting window.
            fd_waveforms = list()
            td_waveforms = list()
            for _, params in samples.iterrows():
                params = dict(params)
                wf_pols = waveform_generator.frequency_domain_strain(params)
                fd_waveform = interferometer.get_detector_response(wf_pols, params)
                fd_waveforms.append(fd_waveform[frequency_idxs])
                whitened_fd_waveform = interferometer.whiten_frequency_series(fd_waveform)
                td_waveform = interferometer.get_whitened_time_series_from_whitened_frequency_series(
                    whitened_fd_waveform
                )[time_idxs]
                td_waveforms.append(td_waveform)
            fd_waveforms = asd_from_freq_series(fd_waveforms, delta_f)
            td_waveforms = np.array(td_waveforms)

            delta = (1 + level) / 2
            upper_percentile = delta * 100
            lower_percentile = (1 - delta) * 100
            logger.debug(
                'Plotting posterior between the {} and {} percentiles'.format(
                    lower_percentile, upper_percentile
                )
            )
            fd_lower = np.percentile(fd_waveforms, lower_percentile, axis=0)
            fd_upper = np.percentile(fd_waveforms, upper_percentile, axis=0)
            td_lower = np.percentile(td_waveforms, lower_percentile, axis=0)
            td_upper = np.percentile(td_waveforms, upper_percentile, axis=0)

            if is_html:
                interval_name = "{:.2f}% credible interval".format(
                    upper_percentile - lower_percentile)
                fig.add_trace(
                    go.Scatter(
                        x=plot_frequencies, y=np.median(fd_waveforms, axis=0),
                        fill=None,
                        mode='lines', line_color=WAVEFORM_COLOR,
                        opacity=1,
                        name="Median reconstructed",
                        legendgroup='median',
                    ),
                    row=1,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_frequencies, y=fd_lower,
                        fill=None,
                        mode='lines',
                        line_color=WAVEFORM_COLOR,
                        opacity=0.1,
                        name=interval_name,
                        legendgroup='uncertainty',
                    ),
                    row=1,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_frequencies, y=fd_upper,
                        fill='tonexty',
                        mode='lines',
                        line_color=WAVEFORM_COLOR,
                        opacity=0.1,
                        name=interval_name,
                        legendgroup='uncertainty',
                        showlegend=False,
                    ),
                    row=1,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_times, y=np.median(td_waveforms, axis=0),
                        fill=None,
                        mode='lines', line_color=WAVEFORM_COLOR,
                        opacity=1,
                        name="Median reconstructed",
                        legendgroup='median',
                        showlegend=False,
                    ),
                    row=2,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_times, y=td_lower,
                        fill=None,
                        mode='lines',
                        line_color=WAVEFORM_COLOR,
                        opacity=0.1,
                        name=interval_name,
                        legendgroup='uncertainty',
                        showlegend=False,
                    ),
                    row=2,
                    col=1,
                )
                fig.add_trace(
                    go.Scatter(
                        x=plot_times, y=td_upper,
                        fill='tonexty',
                        mode='lines',
                        line_color=WAVEFORM_COLOR,
                        opacity=0.1,
                        name=interval_name,
                        legendgroup='uncertainty',
                        showlegend=False,
                    ),
                    row=2,
                    col=1,
                )
            else:
                fd_mean = np.mean(fd_waveforms, axis=0)
                # Lower y-limit of the ASD panel: the first frequency bin of
                # the mean reconstruction, divided by 1e3.
                lower_limit = fd_mean[0] / 1e3
                axs[0].loglog(
                    plot_frequencies, fd_mean,
                    color=WAVEFORM_COLOR, label='Mean reconstructed')
                axs[0].fill_between(
                    plot_frequencies, fd_lower, fd_upper,
                    color=WAVEFORM_COLOR,
                    label=r'{}% credible interval'.format(int(upper_percentile - lower_percentile)),
                    alpha=0.3)
                axs[1].plot(
                    plot_times, np.mean(td_waveforms, axis=0),
                    color=WAVEFORM_COLOR)
                axs[1].fill_between(
                    plot_times, td_lower, td_upper,
                    color=WAVEFORM_COLOR,
                    alpha=0.3)

            if self.injection_parameters is not None:
                try:
                    hf_inj = waveform_generator.frequency_domain_strain(
                        self.injection_parameters)
                    hf_inj_det = interferometer.get_detector_response(
                        hf_inj, self.injection_parameters)
                    ht_inj_det = infft(
                        hf_inj_det * np.sqrt(2. / interferometer.sampling_frequency) /
                        interferometer.amplitude_spectral_density_array,
                        self.sampling_frequency)[time_idxs]
                    injection_asd = asd_from_freq_series(
                        hf_inj_det[frequency_idxs], delta_f)
                    if is_html:
                        fig.add_trace(
                            go.Scatter(
                                x=plot_frequencies,
                                y=injection_asd,
                                fill=None,
                                mode='lines',
                                line=dict(color=INJECTION_COLOR, dash='dot'),
                                name="Injection",
                                legendgroup='injection',
                            ),
                            row=1,
                            col=1,
                        )
                        fig.add_trace(
                            go.Scatter(
                                x=plot_times, y=ht_inj_det,
                                fill=None,
                                mode='lines',
                                line=dict(color=INJECTION_COLOR, dash='dot'),
                                name="Injection",
                                legendgroup='injection',
                                showlegend=False,
                            ),
                            row=2,
                            col=1,
                        )
                    else:
                        axs[0].loglog(
                            plot_frequencies, injection_asd,
                            color=INJECTION_COLOR, label='Injection', linestyle=':')
                        axs[1].plot(
                            plot_times, ht_inj_det,
                            color=INJECTION_COLOR, linestyle=':')
                    logger.debug('Plotted injection.')
                except IndexError as e:
                    logger.info('Failed to plot injection with message {}.'.format(e))

            f_domain_x_label = "$f [\\mathrm{Hz}]$"
            f_domain_y_label = "$\\mathrm{ASD} \\left[\\mathrm{Hz}^{-1/2}\\right]$"
            t_domain_x_label = "$t - {} [s]$".format(interferometer.strain_data.start_time)
            t_domain_y_label = "Whitened Strain"
            if is_html:
                fig.update_xaxes(title_text=f_domain_x_label, type="log", row=1)
                fig.update_yaxes(title_text=f_domain_y_label, type="log", row=1)
                fig.update_xaxes(title_text=t_domain_x_label, type="linear", row=2)
                fig.update_yaxes(title_text=t_domain_y_label, type="linear", row=2)
            else:
                axs[0].set_xlim(interferometer.minimum_frequency,
                                interferometer.maximum_frequency)
                axs[1].set_xlim(start_time, end_time)
                axs[0].set_ylim(lower_limit)
                axs[0].set_xlabel(f_domain_x_label)
                axs[0].set_ylabel(f_domain_y_label)
                axs[1].set_xlabel(t_domain_x_label)
                axs[1].set_ylabel(t_domain_y_label)
                axs[0].legend(loc='lower left', ncol=2)

            if not save:
                return fig

            filename = os.path.join(
                self.outdir,
                self.label + '_{}_waveform.{}'.format(
                    interferometer.name, format))
            if is_html:
                plot(fig, filename=filename, include_mathjax='cdn', auto_open=False)
            else:
                fig.tight_layout()
                safe_save_figure(
                    fig=fig, filename=filename,
                    format=format, dpi=600
                )
                plt.close(fig)
            logger.debug("Waveform figure saved to {}".format(filename))
        finally:
            # Restore the global font size whether or not plotting succeeded.
            # The html path never changed it, and never imported rcParams.
            if not is_html:
                rcParams["font.size"] = old_font_size

    def plot_skymap(
            self, maxpts=None, trials=5, jobs=1, enable_multiresolution=True,
            objid=None, instruments=None, geo=False, dpi=600,
            transparent=False, colorbar=False, contour=(50, 90),
            annotate=True, cmap='cylon', load_pickle=False):
        """ Generate a fits file and sky map from a result

        Code adapted from ligo.skymap.tool.ligo_skymap_from_samples and
        ligo.skymap.tool.plot_skymap. Note, the use of this additionally
        requires the installation of ligo.skymap.

        Outputs written to ``self.outdir``:

        - ``{label}_skypost.obj``: the pickled sky KDE (only when it is
          computed, i.e. ``load_pickle`` is False)
        - ``{label}_skymap.fits``: the HEALPix sky map
        - ``{label}_skymap.png``: the projected sky map figure

        If astropy, ligo.skymap or healpy cannot be imported, an info message
        is logged and nothing is produced.

        Parameters
        ==========
        maxpts: int
            Maximum number of samples to use, if None all samples are used
        trials: int
            Number of trials at each clustering number
        jobs: int
            Number of multiple threads
        enable_multiresolution: bool
            Generate a multiresolution HEALPix map (default: True)
        objid: str
            Event ID to store in FITS header
        instruments: str
            Name of detectors
        geo: bool
            Plot in geographic coordinates (lat, lon) instead of RA, Dec
        dpi: int
            Resolution of figure in dots per inch
        transparent: bool
            Save image with transparent background
        colorbar: bool
            Show colorbar
        contour: sequence of float or None
            Credible levels (in percent) to draw as contours; None disables
            contours. Default is (50, 90).
        annotate: bool
            Annotate image with details
        cmap: str
            Name of the colormap to use
        load_pickle: bool, str
            If true, load the cached pickle file (default name), or the
            pickle-file given as a path. Only load pickle files you trust:
            unpickling can execute arbitrary code.
        """
        import matplotlib.pyplot as plt
        from matplotlib import rcParams

        try:
            from astropy.time import Time
            from ligo.skymap import io, version, plot, postprocess, bayestar, kde
            import healpy as hp
        except ImportError as e:
            logger.info("Unable to generate skymap: error {}".format(e))
            return

        check_directory_exists_and_if_not_mkdir(self.outdir)

        logger.info('Reading samples for skymap')
        data = self.posterior

        if maxpts is not None and maxpts < len(data):
            logger.info('Taking random subsample of chain')
            data = data.sample(maxpts)

        default_obj_filename = os.path.join(self.outdir, '{}_skypost.obj'.format(self.label))

        if load_pickle is False:
            try:
                pts = data[['ra', 'dec', 'luminosity_distance']].values
                kde_class = kde.Clustered2Plus1DSkyKDE
                distance = True
            except KeyError:
                logger.warning("The results file does not contain luminosity_distance")
                pts = data[['ra', 'dec']].values
                kde_class = kde.Clustered2DSkyKDE
                distance = False

            logger.info('Initialising skymap class')
            skypost = kde_class(pts, trials=trials, jobs=jobs)
            logger.info('Pickling skymap to {}'.format(default_obj_filename))
            safe_file_dump(skypost, default_obj_filename, "pickle")

        else:
            if isinstance(load_pickle, str):
                obj_filename = load_pickle
            else:
                obj_filename = default_obj_filename
            logger.info('Reading from pickle {}'.format(obj_filename))
            # SECURITY: pickle.load executes code embedded in the file; the
            # path may be user supplied, so it must point at a trusted file.
            with open(obj_filename, 'rb') as file:
                skypost = pickle.load(file)
            skypost.jobs = jobs
            distance = isinstance(skypost, kde.Clustered2Plus1DSkyKDE)

        logger.info('Making skymap')
        hpmap = skypost.as_healpix()
        if not enable_multiresolution:
            hpmap = bayestar.rasterize(hpmap)

        hpmap.meta.update(io.fits.metadata_for_version_module(version))
        hpmap.meta['creator'] = "bilby"
        hpmap.meta['origin'] = 'LIGO/Virgo'
        hpmap.meta['gps_creation_time'] = Time.now().gps
        hpmap.meta['history'] = ""
        if objid is not None:
            hpmap.meta['objid'] = objid
        if instruments:
            hpmap.meta['instruments'] = instruments
        if distance:
            hpmap.meta['distmean'] = np.mean(data['luminosity_distance'])
            hpmap.meta['diststd'] = np.std(data['luminosity_distance'])

        try:
            geocent_time = data['geocent_time']
            hpmap.meta['gps_time'] = geocent_time.mean()
        except KeyError:
            logger.warning('Cannot determine the event time from geocent_time')

        fits_filename = os.path.join(self.outdir, "{}_skymap.fits".format(self.label))
        logger.info('Saving skymap fits-file to {}'.format(fits_filename))
        io.write_sky_map(fits_filename, hpmap, nest=True)

        skymap, metadata = io.fits.read_sky_map(fits_filename, nest=None)
        nside = hp.npix2nside(len(skymap))

        # Convert sky map from probability to probability per square degree.
        deg2perpix = hp.nside2pixarea(nside, degrees=True)
        probperdeg2 = skymap / deg2perpix

        # Draw on a dedicated figure (not whatever figure happens to be
        # current) and always close it, even if plotting fails.
        fig = plt.figure()
        try:
            if geo:
                obstime = Time(metadata['gps_time'], format='gps').utc.isot
                ax = plt.axes(projection='geo degrees mollweide', obstime=obstime)
            else:
                ax = plt.axes(projection='astro hours mollweide')
            ax.grid()

            # Plot sky map.
            vmax = probperdeg2.max()
            img = ax.imshow_hpx(
                (probperdeg2, 'ICRS'), nested=metadata['nest'], vmin=0., vmax=vmax,
                cmap=cmap)

            # Add colorbar.
            if colorbar:
                cb = plot.colorbar(img)
                cb.set_label(r'prob. per deg$^2$')

            if contour is not None:
                credible_levels = 100 * postprocess.find_greedy_credible_levels(skymap)
                contours = ax.contour_hpx(
                    (credible_levels, 'ICRS'), nested=metadata['nest'],
                    colors='k', linewidths=0.5, levels=contour)
                # "%" must be escaped when matplotlib renders text with TeX.
                fmt = r'%g\%%' if rcParams['text.usetex'] else '%g%%'
                plt.clabel(contours, fmt=fmt, fontsize=6, inline=True)

            # Add continents.
            if geo:
                geojson_filename = os.path.join(
                    os.path.dirname(plot.__file__), 'ne_simplified_coastline.json')
                with open(geojson_filename, 'r') as geojson_file:
                    geoms = json.load(geojson_file)['geometries']
                verts = [coord for geom in geoms
                         for coord in zip(*geom['coordinates'])]
                plt.plot(*verts, color='0.5', linewidth=0.5,
                         transform=ax.get_transform('world'))

            # Add a white outline to all text to make it stand out from the background.
            plot.outline_text(ax)

            if annotate:
                text = []
                try:
                    objid = metadata['objid']
                except KeyError:
                    pass
                else:
                    text.append('event ID: {}'.format(objid))
                if contour:
                    # Area enclosed by each contour level: number of pixels
                    # with credible level below it, times the pixel area.
                    pp = np.round(contour).astype(int)
                    ii = np.round(np.searchsorted(np.sort(credible_levels), contour) *
                                  deg2perpix).astype(int)
                    for i, p in zip(ii, pp):
                        text.append(
                            '{:d}% area: {:d} deg$^2$'.format(p, i))
                ax.text(1, 1, '\n'.join(text), transform=ax.transAxes, ha='right')

            filename = os.path.join(self.outdir, "{}_skymap.png".format(self.label))
            logger.info("Generating 2D projected skymap to {}".format(filename))
            safe_save_figure(fig=fig, filename=filename, dpi=dpi, transparent=transparent)
        finally:
            plt.close(fig)
# Short-hand name for the compact-binary result class; both names refer to
# the very same class object.
CBCResult = CompactBinaryCoalescenceResult