pycobertura report

Filename Stmts Miss Cover Missing
ligo/skymap/__init__.py 13 2 84.62% 32-33
ligo/skymap/distance.py 94 16 82.98% 465-478, 491, 640
ligo/skymap/healpix_tree.py 129 42 67.44% 95, 100-101, 110, 163-167, 185-212, 218-241, 292-304
ligo/skymap/kde.py 199 42 78.89% 68-73, 78-112, 130, 150-151, 216, 224-225, 263-267, 341, 355-356, 437, 443, 485
ligo/skymap/moc.py 56 1 98.21% 134
ligo/skymap/bayestar/__init__.py 229 42 81.66% 74, 83-88, 95-111, 123, 171, 224, 229, 244, 254, 290, 309, 388-389, 404-405, 492-507
ligo/skymap/bayestar/ez_emcee.py 36 30 16.67% 25, 107-156
ligo/skymap/bayestar/filter.py 194 10 94.85% 102, 252, 270-275, 282, 421, 459, 463
ligo/skymap/bayestar/interpolation.py 118 7 94.07% 208-214
ligo/skymap/bayestar/ptemcee.py 27 19 29.63% 12-26, 41-49, 54-57
ligo/skymap/coordinates/__init__.py 8 0 100.00%
ligo/skymap/coordinates/detector.py 21 0 100.00%
ligo/skymap/coordinates/eigenframe.py 31 5 83.87% 98-101, 113
ligo/skymap/io/__init__.py 8 0 100.00%
ligo/skymap/io/fits.py 149 18 87.92% 442, 447-448, 457, 494-500, 523-538
ligo/skymap/io/hdf5.py 68 2 97.06% 299-300
ligo/skymap/io/events/__init__.py 8 0 100.00%
ligo/skymap/io/events/base.py 62 14 77.42% 37-43, 46-53, 140, 147
ligo/skymap/io/events/detector_disabled.py 34 0 100.00%
ligo/skymap/io/events/gracedb.py 31 2 93.55% 32, 64
ligo/skymap/io/events/hdf.py 137 3 97.81% 59, 228-229
ligo/skymap/io/events/ligolw.py 178 14 92.13% 53, 66, 72-73, 136-137, 152-153, 187-188, 198, 219, 227, 250
ligo/skymap/io/events/magic.py 33 0 100.00%
ligo/skymap/io/events/sqlite.py 20 0 100.00%
ligo/skymap/plot/__init__.py 8 0 100.00%
ligo/skymap/plot/allsky.py 245 61 75.10% 220-221, 228-229, 245-254, 265-269, 274-293, 350, 379-388, 415, 442-449, 463-466, 502, 514-522, 582-584, 650-652, 748, 774
ligo/skymap/plot/angle.py 13 0 100.00%
ligo/skymap/plot/backdrop.py 46 24 47.83% 44-47, 87-89, 129-149, 187-208, 212-214
ligo/skymap/plot/bayes_factor.py 33 0 100.00%
ligo/skymap/plot/cmap.py 17 0 100.00%
ligo/skymap/plot/cylon.py 1 0 100.00%
ligo/skymap/plot/marker.py 21 0 100.00%
ligo/skymap/plot/poly.py 59 50 15.25% 32-42, 53-58, 69-157, 165-184
ligo/skymap/plot/pp.py 83 8 90.36% 98-102, 164, 167-168, 283
ligo/skymap/plot/util.py 31 2 93.55% 41, 49
ligo/skymap/postprocess/__init__.py 8 0 100.00%
ligo/skymap/postprocess/contour.py 60 4 93.33% 72-75
ligo/skymap/postprocess/cosmology.py 36 6 83.33% 34-37, 68-71
ligo/skymap/postprocess/crossmatch.py 149 7 95.30% 297, 359, 382, 418, 442-445
ligo/skymap/postprocess/ellipse.py 53 9 83.02% 330-336, 360, 377
ligo/skymap/postprocess/util.py 36 16 55.56% 55, 81-84, 88-94, 98-101
ligo/skymap/tool/__init__.py 196 47 76.02% 44-47, 61, 74-77, 91-94, 97, 100-107, 109-113, 192-195, 315-321, 390, 404, 426-447
ligo/skymap/tool/bayestar_inject.py 250 148 40.80% 300-555
ligo/skymap/tool/bayestar_localize_coincs.py 62 19 69.35% 108, 114-143, 146, 155, 169-173, 181
ligo/skymap/tool/bayestar_localize_lvalert.py 78 19 75.64% 103, 114, 121-123, 127-129, 145, 158, 162-172, 176-177
ligo/skymap/tool/bayestar_mcmc.py 74 57 22.97% 68, 74-194
ligo/skymap/tool/bayestar_realize_coincs.py 178 3 98.31% 346, 380, 385
ligo/skymap/tool/bayestar_sample_model_psd.py 53 0 100.00%
ligo/skymap/tool/ligo_skymap_combine.py 76 4 94.74% 70, 96, 102, 107
ligo/skymap/tool/ligo_skymap_constellations.py 27 18 33.33% 42-61
ligo/skymap/tool/ligo_skymap_contour.py 31 18 41.94% 53-81
ligo/skymap/tool/ligo_skymap_contour_moc.py 25 14 44.00% 47-77
ligo/skymap/tool/ligo_skymap_flatten.py 34 2 94.12% 56-57
ligo/skymap/tool/ligo_skymap_from_samples.py 95 13 86.32% 112-115, 118-119, 137, 143, 150-151, 156, 176-177, 182
ligo/skymap/tool/ligo_skymap_plot.py 79 5 93.67% 136-147, 162-163
ligo/skymap/tool/ligo_skymap_plot_airmass.py 100 77 23.00% 64, 68, 76-219
ligo/skymap/tool/ligo_skymap_plot_coherence.py 23 13 43.48% 39-56
ligo/skymap/tool/ligo_skymap_plot_observability.py 73 49 32.88% 67, 75-148
ligo/skymap/tool/ligo_skymap_plot_pp_samples.py 59 44 25.42% 29-35, 62-128
ligo/skymap/tool/ligo_skymap_plot_stats.py 125 3 97.60% 70, 135, 137
ligo/skymap/tool/ligo_skymap_plot_volume.py 122 10 91.80% 88, 94, 98-99, 122, 163, 235-236, 242-243
ligo/skymap/tool/ligo_skymap_stats.py 82 9 89.02% 145-146, 151, 168, 172, 174, 190, 192, 220
ligo/skymap/tool/ligo_skymap_unflatten.py 20 11 45.00% 36-48
ligo/skymap/tool/matplotlib.py 77 6 92.21% 35-36, 48-49, 72-73
ligo/skymap/util/__init__.py 8 0 100.00%
ligo/skymap/util/file.py 25 12 52.00% 30-42
ligo/skymap/util/ilwd.py 49 4 91.84% 79-81, 132
ligo/skymap/util/math.py 4 0 100.00%
ligo/skymap/util/numpy.py 11 0 100.00%
ligo/skymap/util/progress.py 40 3 92.50% 39-40, 70
ligo/skymap/util/sqlite.py 34 1 97.06% 152
ligo/skymap/util/stopwatch.py 48 12 75.00% 38-41, 44, 54, 57-59, 67, 86-88
src/bayestar_distance.c 253 6 97.63% 93-94, 126-129, 455-456
src/bayestar_moc.c 56 4 92.86% 117, 126, 141-142
src/bayestar_sky_map.c 606 53 91.25% 250, 266, 332, 431, 449-453, 655, 670-671, 710, 902, 932-933, 943-944, 990, 1006, 1022, 1048-1049, 1111-1181
src/cubic_interp.c 86 0 100.00%
src/cubic_interp_test.c 187 0 100.00%
src/omp_interruptible.h 16 6 62.50% 157-163
src/vmath.h 7 0 100.00%
TOTAL 6151 1146 81.37%

ligo/skymap/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
# This file is adapted from the Astropy package template, which is licensed
# under a 3-clause BSD style license - see licenses/TEMPLATE_LICENSE.rst

# Packages may add whatever they like to this file, but
# should keep this content at the top.
# ----------------------------------------------------------------------------
from ._astropy_init import *   # noqa
# ----------------------------------------------------------------------------

__all__ = ('omp',)


class Omp:
    """Interface to OpenMP runtime settings.

    Attributes
    ----------
    num_threads : int
        The number of OpenMP threads. Reading and writing this attribute
        call :man:`omp_get_num_threads` and :man:`omp_set_num_threads`,
        respectively.

    """

    @property
    def num_threads(self):
        # Imported lazily so that importing the package does not eagerly
        # load the compiled extension module.
        from .core import get_num_threads
        return get_num_threads()

    @num_threads.setter
    def num_threads(self, n):
        from .core import set_num_threads
        set_num_threads(n)


# Expose a single instance and hide the class itself from the namespace.
omp = Omp()
del Omp

ligo/skymap/distance.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
548  
549  
550  
551  
552  
553  
554  
555  
556  
557  
558  
559  
560  
561  
562  
563  
564  
565  
566  
567  
568  
569  
570  
571  
572  
573  
574  
575  
576  
577  
578  
579  
580  
581  
582  
583  
584  
585  
586  
587  
588  
589  
590  
591  
592  
593  
594  
595  
596  
597  
598  
599  
600  
601  
602  
603  
604  
605  
606  
607  
608  
609  
610  
611  
612  
613  
614  
615  
616  
617  
618  
619  
620  
621  
622  
623  
624  
625  
626  
627  
628  
629  
630  
631  
632  
633  
634  
635  
636  
637  
638  
639  
640  
641  
642  
643  
644  
645  
646  
647  
648  
649  
650  
651  
652  
653  
654  
655  
656  
657  
658  
659  
660  
661  
662  
663  
664  
665  
666  
667  
668  
669  
670  
671  
672  
673  
674  
675  
#
# Copyright (C) 2017-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Distance distribution functions from [1]_, [2]_, [3]_.

References
----------
.. [1] Singer, Chen, & Holz, 2016. "Going the Distance: Mapping Host Galaxies
   of LIGO and Virgo Sources in Three Dimensions Using Local Cosmography and
   Targeted Follow-up." ApJL, 829, L15.
   :doi:`10.3847/2041-8205/829/1/L15`

.. [2] Singer, Chen, & Holz, 2016. "Supplement: 'Going the Distance: Mapping
   Host Galaxies of LIGO and Virgo Sources in Three Dimensions Using Local
   Cosmography and Targeted Follow-up' (2016, ApJL, 829, L15)." ApJS, 226, 10.
   :doi:`10.3847/0067-0049/226/1/10`

.. [3] https://asd.gsfc.nasa.gov/Leo.Singer/going-the-distance

"""

import astropy_healpix as ah
import numpy as np
import healpy as hp
import scipy.special
from .core import (conditional_pdf, conditional_cdf, conditional_ppf,
                   moments_to_parameters, parameters_to_moments, volume_render,
                   marginal_pdf, marginal_cdf, marginal_ppf)
from .util.numpy import add_newdoc_ufunc, require_contiguous_aligned

# Public API of this module. The final entry was previously a duplicate of
# 'parameters_to_moments'; it should be 'parameters_to_marginal_moments',
# which is defined below but was missing from the export list.
__all__ = ('conditional_pdf', 'conditional_cdf', 'conditional_ppf',
           'moments_to_parameters', 'parameters_to_moments', 'volume_render',
           'marginal_pdf', 'marginal_cdf', 'marginal_ppf', 'ud_grade',
           'conditional_kde', 'cartesian_kde_to_moments', 'principal_axes',
           'parameters_to_marginal_moments')


# Attach docstrings (including doctests) to the ufuncs imported from the
# compiled ``.core`` extension module.
add_newdoc_ufunc(conditional_pdf, """\
Conditional distance probability density function (ansatz).

Parameters
----------
r : `numpy.ndarray`
    Distance (Mpc)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
pdf : `numpy.ndarray`
    Conditional probability density according to ansatz.

""")


add_newdoc_ufunc(conditional_cdf, """\
Cumulative conditional distribution of distance (ansatz).

Parameters
----------
r : `numpy.ndarray`
    Distance (Mpc)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
pdf : `numpy.ndarray`
    Conditional probability density according to ansatz.

Examples
--------
Test against numerical integral of pdf.

>>> import scipy.integrate
>>> distmu = 10.0
>>> distsigma = 5.0
>>> distnorm = 1.0
>>> r = 8.0
>>> expected, _ = scipy.integrate.quad(
...     conditional_pdf, 0, r,
...     (distmu, distsigma, distnorm))
>>> result = conditional_cdf(
...     r, distmu, distsigma, distnorm)
>>> np.testing.assert_almost_equal(result, expected)

""")


add_newdoc_ufunc(conditional_ppf, """\
Point percent function (inverse cdf) of distribution of distance (ansatz).

Parameters
----------
p : `numpy.ndarray`
    The cumulative distribution function
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
r : `numpy.ndarray`
    Distance at which the cdf is equal to `p`.

Examples
--------
Test against numerical estimate.

>>> import scipy.optimize
>>> distmu = 10.0
>>> distsigma = 5.0
>>> distnorm = 1.0
>>> p = 0.16  # "one-sigma" lower limit
>>> expected_r16 = scipy.optimize.brentq(
... lambda r: conditional_cdf(r, distmu, distsigma, distnorm) - p, 0.0, 100.0)
>>> r16 = conditional_ppf(p, distmu, distsigma, distnorm)
>>> np.testing.assert_almost_equal(r16, expected_r16)

""")


add_newdoc_ufunc(moments_to_parameters, """\
Convert ansatz moments to parameters.

This function is the inverse of `parameters_to_moments`.

Parameters
----------
distmean : `numpy.ndarray`
    Conditional mean of distance (Mpc)
diststd : `numpy.ndarray`
    Conditional standard deviation of distance (Mpc)

Returns
-------
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

""")


add_newdoc_ufunc(parameters_to_moments, """\
Convert ansatz parameters to moments.

This function is the inverse of `moments_to_parameters`.

Parameters
----------
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)

Returns
-------
distmean : `numpy.ndarray`
    Conditional mean of distance (Mpc)
diststd : `numpy.ndarray`
    Conditional standard deviation of distance (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Examples
--------
For mu=0, sigma=1, the ansatz is a chi distribution with 3 degrees of
freedom, and the moments have simple expressions.

>>> mean, std, norm = parameters_to_moments(0, 1)
>>> expected_mean = 2 * np.sqrt(2 / np.pi)
>>> expected_std = np.sqrt(3 - expected_mean**2)
>>> expected_norm = 2.0
>>> np.testing.assert_allclose(mean, expected_mean)
>>> np.testing.assert_allclose(std, expected_std)
>>> np.testing.assert_allclose(norm, expected_norm)

Check that the moments scale as expected when we vary sigma.

>>> sigma = np.logspace(-8, 8)
>>> mean, std, norm = parameters_to_moments(0, sigma)
>>> np.testing.assert_allclose(mean, expected_mean * sigma)
>>> np.testing.assert_allclose(std, expected_std * sigma)
>>> np.testing.assert_allclose(norm, expected_norm / sigma**2)

Check some more arbitrary values using numerical quadrature:

>>> import scipy.integrate
>>> sigma = 1.0
>>> for mu in np.linspace(-10, 10):
...     mean, std, norm = parameters_to_moments(mu, sigma)
...     moments = np.empty(3)
...     for k in range(3):
...         moments[k], _ = scipy.integrate.quad(
...             lambda r: r**k * conditional_pdf(r, mu, sigma, 1.0),
...             0, np.inf)
...     expected_norm = 1 / moments[0]
...     expected_mean, r2 = moments[1:] * expected_norm
...     expected_std = np.sqrt(r2 - np.square(expected_mean))
...     np.testing.assert_approx_equal(mean, expected_mean, 5)
...     np.testing.assert_approx_equal(std, expected_std, 5)
...     np.testing.assert_approx_equal(norm, expected_norm, 5)

""")


# NOTE(review): the doctest below builds the image grid with
# ``np.logspace(-dmax, dmax, n)``; a symmetric *linear* grid
# (``np.linspace``) looks intended here — verify against the renderer.
add_newdoc_ufunc(volume_render, """\
Perform volumetric rendering of a 3D sky map.

Parameters
----------
x : `numpy.ndarray`
    X-coordinate in rendered image
y : `numpy.ndarray`
    Y-coordinate in rendered image
max_distance : float
    Limit of integration from `-max_distance` to `+max_distance`
axis0 : int
    Index of axis to assign to x-coordinate
axis1 : int
    Index of axis to assign to y-coordinate
R : `numpy.ndarray`
    Rotation matrix as provided by `principal_axes`
nest : bool
    HEALPix ordering scheme
prob : `numpy.ndarray`
    Marginal probability (pix^-2)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
image : `numpy.ndarray`
    Rendered image

Examples
--------
Test volume rendering of a normal unit sphere...
First, set up the 3D sky map.

>>> nside = 32
>>> npix = ah.nside_to_npix(nside)
>>> prob = np.ones(npix) / npix
>>> distmu = np.zeros(npix)
>>> distsigma = np.ones(npix)
>>> distnorm = np.ones(npix) * 2.0

The conditional distance distribution should be a chi distribution with
3 degrees of freedom.

>>> from scipy.stats import norm, chi
>>> r = np.linspace(0, 10.0)
>>> actual = conditional_pdf(r, distmu[0], distsigma[0], distnorm[0])
>>> expected = chi(3).pdf(r)
>>> np.testing.assert_almost_equal(actual, expected)

Next, run the volume renderer.

>>> dmax = 4.0
>>> n = 64
>>> s = np.logspace(-dmax, dmax, n)
>>> x, y = np.meshgrid(s, s)
>>> R = np.eye(3)
>>> P = volume_render(x, y, dmax, 0, 1, R, False,
...                   prob, distmu, distsigma, distnorm)

Next, integrate analytically.

>>> P_expected = norm.pdf(x) * norm.pdf(y) * (norm.cdf(dmax) - norm.cdf(-dmax))

Compare the two.

>>> np.testing.assert_almost_equal(P, P_expected, decimal=4)

Check that we get the same answer if the input is in ring ordering.
FIXME: this is a very weak test, because the input sky map is isotropic!

>>> P = volume_render(x, y, dmax, 0, 1, R, True,
...                   prob, distmu, distsigma, distnorm)
>>> np.testing.assert_almost_equal(P, P_expected, decimal=4)

Last, check that we don't have a coordinate singularity at the origin.

>>> x = np.concatenate(([0], np.logspace(1 - n, 0, n) * dmax))
>>> y = 0.0
>>> P = volume_render(x, y, dmax, 0, 1, R, False,
...                   prob, distmu, distsigma, distnorm)
>>> P_expected = norm.pdf(x) * norm.pdf(y) * (norm.cdf(dmax) - norm.cdf(-dmax))
>>> np.testing.assert_allclose(P, P_expected, rtol=1e-4)

""")
# Wrap so inputs are forced to contiguous, aligned buffers (presumably a
# requirement of the underlying C implementation — see .util.numpy).
volume_render = require_contiguous_aligned(volume_render)


add_newdoc_ufunc(marginal_pdf, """\
Calculate all-sky marginal pdf (ansatz).

Parameters
----------
r : `numpy.ndarray`
    Distance (Mpc)
prob : `numpy.ndarray`
    Marginal probability (pix^-2)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
pdf : `numpy.ndarray`
    Marginal probability density according to ansatz.

Examples
--------

>>> npix = 12
>>> prob, distmu, distsigma, distnorm = np.random.uniform(size=(4, 12))
>>> r = np.linspace(0, 1)
>>> pdf_expected = np.dot(
...     conditional_pdf(r[:, np.newaxis], distmu, distsigma, distnorm), prob)
>>> pdf = marginal_pdf(r, prob, distmu, distsigma, distnorm)
>>> np.testing.assert_allclose(pdf, pdf_expected, rtol=1e-4)

""")
marginal_pdf = require_contiguous_aligned(marginal_pdf)


add_newdoc_ufunc(marginal_cdf, """\
Calculate all-sky marginal cdf (ansatz).

Parameters
----------
r : `numpy.ndarray`
    Distance (Mpc)
prob : `numpy.ndarray`
    Marginal probability (pix^-2)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
cdf : `numpy.ndarray`
    Marginal cumulative probability according to ansatz.

Examples
--------

>>> npix = 12
>>> prob, distmu, distsigma, distnorm = np.random.uniform(size=(4, 12))
>>> r = np.linspace(0, 1)
>>> cdf_expected = np.dot(
...     conditional_cdf(r[:, np.newaxis], distmu, distsigma, distnorm), prob)
>>> cdf = marginal_cdf(r, prob, distmu, distsigma, distnorm)
>>> np.testing.assert_allclose(cdf, cdf_expected, rtol=1e-4)

""")
marginal_cdf = require_contiguous_aligned(marginal_cdf)


add_newdoc_ufunc(marginal_ppf, """\
Point percent function (inverse cdf) of marginal distribution of distance
(ansatz).

Parameters
----------
p : `numpy.ndarray`
    The cumulative distribution function
prob : `numpy.ndarray`
    Marginal probability (pix^-2)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
r : `numpy.ndarray`
    Distance at which the cdf is equal to `p`.

Examples
--------

>>> from astropy.utils.misc import NumpyRNGContext
>>> npix = 12
>>> with NumpyRNGContext(0):
...     prob, distmu, distsigma, distnorm = np.random.uniform(size=(4, 12))
>>> r_expected = np.linspace(0.4, 0.7)
>>> cdf = marginal_cdf(r_expected, prob, distmu, distsigma, distnorm)
>>> r = marginal_ppf(cdf, prob, distmu, distsigma, distnorm)
>>> np.testing.assert_allclose(r, r_expected, rtol=1e-4)

""")
marginal_ppf = require_contiguous_aligned(marginal_ppf)


def ud_grade(prob, distmu, distsigma, *args, **kwargs):
    """
    Upsample or downsample a distance-resolved sky map.

    Parameters
    ----------
    prob : `numpy.ndarray`
        Marginal probability (pix^-2)
    distmu : `numpy.ndarray`
        Distance location parameter (Mpc)
    distsigma : `numpy.ndarray`
        Distance scale parameter (Mpc)
    *args, **kwargs :
        Additional arguments to `healpy.ud_grade` (e.g.,
        `nside`, `order_in`, `order_out`).

    Returns
    -------
    prob : `numpy.ndarray`
        Resampled marginal probability (pix^-2)
    distmu : `numpy.ndarray`
        Resampled distance location parameter (Mpc)
    distsigma : `numpy.ndarray`
        Resampled distance scale parameter (Mpc)
    distnorm : `numpy.ndarray`
        Resampled distance normalization factor (Mpc^-2)

    """
    # Pixels with non-finite ansatz parameters carry no distance information.
    bad = ~(np.isfinite(distmu) & np.isfinite(distsigma))
    # Convert to moment space (conditional mean and std. dev.), where
    # probability-weighted averaging across pixels is meaningful.
    distmean, diststd, _ = parameters_to_moments(distmu, distsigma)
    distmean[bad] = 0
    diststd[bad] = 0
    # FIX: resample the probability-weighted *mean* (prob * distmean).
    # The previous code used ``prob * distmu`` here, which left the
    # computed, bad-pixel-zeroed ``distmean`` unused and propagated the
    # raw (possibly negative) location parameter into the average.
    distmean = hp.ud_grade(prob * distmean, *args, power=-2, **kwargs)
    diststd = hp.ud_grade(prob * np.square(diststd), *args, power=-2, **kwargs)
    prob = hp.ud_grade(prob, *args, power=-2, **kwargs)
    # Turn the probability-weighted sums back into conditional moments.
    distmean /= prob
    diststd = np.sqrt(diststd / prob)
    # A resampled pixel is bad only if all of its contributing pixels
    # were bad; mark those with the sentinel values used in this module.
    bad = ~hp.ud_grade(~bad, *args, power=-2, **kwargs)
    distmean[bad] = np.inf
    diststd[bad] = 1
    distmu, distsigma, distnorm = moments_to_parameters(distmean, diststd)
    return prob, distmu, distsigma, distnorm


def _conditional_kde(n, X, Cinv, W):
    Cinv_n = np.dot(Cinv, n)
    cinv = np.dot(n, Cinv_n)
    x = np.dot(Cinv_n, X) / cinv
    w = W * (0.5 / np.pi) * np.sqrt(np.linalg.det(Cinv) / cinv) * np.exp(
        0.5 * (np.square(x) * cinv - (np.dot(Cinv, X) * X).sum(0)))
    return x, cinv, w


def conditional_kde(n, datasets, inverse_covariances, weights):
    """Evaluate `_conditional_kde` for every component of a KDE mixture."""
    results = []
    for X, Cinv, W in zip(datasets, inverse_covariances, weights):
        results.append(_conditional_kde(n, X, Cinv, W))
    return results


def cartesian_kde_to_moments(n, datasets, inverse_covariances, weights):
    """
    Calculate the marginal probability, conditional mean, and conditional
    standard deviation of a mixture of three-dimensional kernel density
    estimators (KDEs), in a given direction specified by a unit vector.

    Parameters
    ----------
    n : `numpy.ndarray`
        A unit vector; an array of length 3.
    datasets : list of `numpy.ndarray`
        A list 2D Numpy arrays specifying the sample points of the KDEs.
        The first dimension of each array is 3.
    inverse_covariances: list of `numpy.ndarray`
        An array of 3x3 matrices specifying the inverses of the covariance
        matrices of the KDEs. The list has the same length as the datasets
        parameter.
    weights : list
        A list of floating-point weights.

    Returns
    -------
    prob : float
        The marginal probability in direction n, integrated over all distances.
    mean : float
        The conditional mean in direction n.
    std : float
        The conditional standard deviation in direction n.

    Examples
    --------
    >>> # Some imports
    >>> import scipy.stats
    >>> import scipy.integrate
    >>> # Construct random dataset for KDE
    >>> np.random.seed(0)
    >>> nclusters = 5
    >>> ndata = np.random.randint(0, 1000, nclusters)
    >>> covs = [np.random.uniform(0, 1, size=(3, 3)) for _ in range(nclusters)]
    >>> covs = [_ + _.T + 3 * np.eye(3) for _ in covs]
    >>> means = np.random.uniform(-1, 1, size=(nclusters, 3))
    >>> datasets = [np.random.multivariate_normal(m, c, n).T
    ...     for m, c, n in zip(means, covs, ndata)]
    >>> weights = ndata / float(np.sum(ndata))
    >>>
    >>> # Construct set of KDEs
    >>> kdes = [scipy.stats.gaussian_kde(_) for _ in datasets]
    >>>
    >>> # Random unit vector n
    >>> n = np.random.normal(size=3)
    >>> n /= np.sqrt(np.sum(np.square(n)))
    >>>
    >>> # Analytically evaluate conditional mean and std. dev. in direction n
    >>> datasets = [_.dataset for _ in kdes]
    >>> inverse_covariances = [_.inv_cov for _ in kdes]
    >>> result_prob, result_mean, result_std = cartesian_kde_to_moments(
    ...     n, datasets, inverse_covariances, weights)
    >>>
    >>> # Numerically integrate conditional distance moments
    >>> def rkbar(k):
    ...     def integrand(r):
    ...         return r ** k * np.sum([kde(r * n) * weight
    ...             for kde, weight in zip(kdes, weights)])
    ...     integral, err = scipy.integrate.quad(integrand, 0, np.inf)
    ...     return integral
    ...
    >>> r0bar = rkbar(2)
    >>> r1bar = rkbar(3)
    >>> r2bar = rkbar(4)
    >>>
    >>> # Extract conditional mean and std. dev.
    >>> r1bar /= r0bar
    >>> r2bar /= r0bar
    >>> expected_prob = r0bar
    >>> expected_mean = r1bar
    >>> expected_std = np.sqrt(r2bar - np.square(r1bar))
    >>>
    >>> # Check that the two methods give almost the same result
    >>> np.testing.assert_almost_equal(result_prob, expected_prob)
    >>> np.testing.assert_almost_equal(result_mean, expected_mean)
    >>> np.testing.assert_almost_equal(result_std, expected_std)
    >>>
    >>> # Check that KDE is normalized over unit sphere.
    >>> nside = 32
    >>> npix = ah.nside_to_npix(nside)
    >>> prob, _, _ = np.transpose([cartesian_kde_to_moments(
    ...     np.asarray(hp.pix2vec(nside, ipix)),
    ...     datasets, inverse_covariances, weights)
    ...     for ipix in range(npix)])
    >>> result_integral = prob.sum() * hp.nside2pixarea(nside)
    >>> np.testing.assert_almost_equal(result_integral, 1.0, decimal=4)

    """
    # Initialize moments of conditional KDE.
    r0bar = 0
    r1bar = 0
    r2bar = 0

    # Loop over KDEs.
    for X, Cinv, W in zip(datasets, inverse_covariances, weights):
        # Restrict this component to the ray along n: x is the location of
        # each sample's 1D Gaussian along the ray, cinv the common inverse
        # variance, w the per-sample weights.
        x, cinv, w = _conditional_kde(n, X, Cinv, W)

        # Accumulate moments of conditional KDE.
        c = 1 / cinv
        x2 = np.square(x)
        # a is the standard normal CDF at x/sqrt(c); b the matching
        # Gaussian density factor.
        a = scipy.special.ndtr(x * np.sqrt(cinv))
        b = np.sqrt(0.5 / np.pi * c) * np.exp(-0.5 * cinv * x2)
        # Closed forms for the radial moments with the r^2 volume-element
        # factor (compare rkbar(2), rkbar(3), rkbar(4) in the doctest).
        r0bar_ = (x2 + c) * a + x * b
        r1bar_ = x * (x2 + 3 * c) * a + (x2 + 2 * c) * b
        r2bar_ = (x2 * x2 + 6 * x2 * c + 3 * c * c) * a + x * (x2 + 5 * c) * b
        r0bar += np.mean(w * r0bar_)
        r1bar += np.mean(w * r1bar_)
        r2bar += np.mean(w * r2bar_)

    # Normalize moments.
    with np.errstate(invalid='ignore'):
        r1bar /= r0bar
        r2bar /= r0bar
    var = r2bar - np.square(r1bar)

    # Handle invalid values.
    if var >= 0:
        mean = r1bar
        std = np.sqrt(var)
    else:
        # Fall back to the sentinel values used elsewhere in this module
        # (infinite mean, unit standard deviation).
        mean = np.inf
        std = 1.0
    prob = r0bar

    # Done!
    return prob, mean, std


def principal_axes(prob, distmu, distsigma, nest=False):
    """Compute a rotation matrix whose columns are the principal axes of the
    probability-weighted second-moment matrix of the sky map."""
    nside = ah.npix_to_nside(len(prob))
    # Only pixels with finite probability and distance parameters contribute.
    good = np.isfinite(prob) & np.isfinite(distmu) & np.isfinite(distsigma)
    pixels = np.flatnonzero(good)
    distmean, diststd, _ = parameters_to_moments(distmu[good], distsigma[good])
    # Weight each direction by probability times the second radial moment.
    mass = prob[good] * (np.square(diststd) + np.square(distmean))
    xyz = np.asarray(hp.pix2vec(nside, pixels, nest=nest))
    second_moment = np.dot(xyz * mass, xyz.T)
    _, V = np.linalg.eigh(second_moment)
    # Keep the rotation proper (determinant +1).
    if np.linalg.det(V) < 0:
        V = -V
    return V


def parameters_to_marginal_moments(prob, distmu, distsigma):
    """Calculate the marginal (integrated all-sky) mean and standard deviation
    of distance from the ansatz parameters.

    Parameters
    ----------
    prob : `numpy.ndarray`
        Marginal probability (pix^-2)
    distmu : `numpy.ndarray`
        Distance location parameter (Mpc)
    distsigma : `numpy.ndarray`
        Distance scale parameter (Mpc)

    Returns
    -------
    distmean : float
        Mean distance (Mpc)
    diststd : float
        Std. deviation of distance (Mpc)

    """
    # Restrict to pixels with finite probability and parameters.
    finite = np.isfinite(prob) & np.isfinite(distmu) & np.isfinite(distsigma)
    weights = prob[finite]
    distmean, diststd, _ = parameters_to_moments(
        distmu[finite], distsigma[finite])
    # Probability-weighted first and second moments of distance.
    first_moment = np.sum(weights * distmean)
    second_moment = np.sum(
        weights * (np.square(diststd) + np.square(distmean)))
    return first_moment, np.sqrt(second_moment - np.square(first_moment))


# Remove module-private helper names so they are not exposed as part of
# this module's namespace.
del add_newdoc_ufunc, require_contiguous_aligned

ligo/skymap/healpix_tree.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Multiresolution HEALPix trees
"""
import astropy_healpix as ah
from astropy import units as u
import numpy as np
import healpy as hp
import collections
import itertools

__all__ = ('HEALPIX_MACHINE_ORDER', 'HEALPIX_MACHINE_NSIDE', 'HEALPixTree',
           'adaptive_healpix_histogram', 'interpolate_nested',
           'reconstruct_nested')


# Maximum 64-bit HEALPix resolution.
HEALPIX_MACHINE_ORDER = 29
HEALPIX_MACHINE_NSIDE = ah.level_to_nside(HEALPIX_MACHINE_ORDER)


# Record yielded by tree traversal when extra=True: resolution of the node,
# resolution of the deepest node, nested index, start/end full-resolution
# pixel range, and the node's payload (samples or map value).
_HEALPixTreeVisitExtra = collections.namedtuple(
    'HEALPixTreeVisit', 'nside full_nside ipix ipix0 ipix1 value')


# Compact record yielded by tree traversal when extra=False.
_HEALPixTreeVisit = collections.namedtuple(
    'HEALPixTreeVisit', 'nside ipix')


class HEALPixTree:
    """Data structure used internally by the function
    adaptive_healpix_histogram()."""

    def __init__(
            self, samples, max_samples_per_pixel, max_order,
            order=0, needs_sort=True):
        """Recursively build the tree.

        ``samples`` are nested pixel indices at the maximum
        (HEALPIX_MACHINE_ORDER) resolution. They must be sorted so that
        itertools.groupby can group them by parent pixel; the top-level
        call sorts them once and recursive calls pass needs_sort=False.
        """
        if needs_sort:
            samples = np.sort(samples)
        if len(samples) >= max_samples_per_pixel and order < max_order:
            # All nodes have 4 children, except for the root node,
            # which has 12.
            nchildren = 12 if order == 0 else 4
            self.samples = None
            self.children = [
                HEALPixTree(
                    [], max_samples_per_pixel, max_order, order=order + 1)
                for i in range(nchildren)]
            for ipix, samples in itertools.groupby(
                    samples, self.key_for_order(order)):
                # ipix % nchildren selects the child slot within this node.
                self.children[ipix % nchildren] = HEALPixTree(
                    list(samples), max_samples_per_pixel, max_order,
                    order=order + 1, needs_sort=False)
        else:
            # There are few enough samples that we can make this cell a leaf.
            self.samples = list(samples)
            self.children = None

    @staticmethod
    def key_for_order(order):
        """Create a function that downsamples full-resolution pixel indices."""
        return lambda ipix: ipix >> np.int64(
            2 * (HEALPIX_MACHINE_ORDER - order))

    @property
    def order(self):
        """Return the maximum HEALPix order required to represent this tree,
        which is the same as the tree depth."""
        if self.children is None:
            return 0
        else:
            return 1 + max(child.order for child in self.children)

    def _visit(self, order, full_order, ipix, extra):
        if self.children is None:
            nside = 1 << order
            # BUGFIX: the full resolution corresponds to full_order, not to
            # this node's order (previously `1 << order`).
            full_nside = 1 << full_order
            ipix0 = ipix << 2 * (full_order - order)
            ipix1 = (ipix + 1) << 2 * (full_order - order)
            if extra:
                yield _HEALPixTreeVisitExtra(
                    nside, full_nside, ipix, ipix0, ipix1, self.samples)
            else:
                yield _HEALPixTreeVisit(nside, ipix)
        else:
            for i, child in enumerate(self.children):
                yield from child._visit(
                    order + 1, full_order, (ipix << 2) + i, extra)

    def _visit_depthfirst(self, extra):
        order = self.order
        for ipix, child in enumerate(self.children):
            yield from child._visit(0, order, ipix, extra)

    def _visit_breadthfirst(self, extra):
        # BUGFIX: in Python 3, sorted() accepts the sort key only as the
        # keyword-only argument `key`; passing it positionally raised
        # TypeError.
        return sorted(
            self._visit_depthfirst(extra),
            key=lambda _: (_.nside, _.ipix))

    def visit(self, order='depthfirst', extra=True):
        """Traverse the leaves of the HEALPix tree.

        Parameters
        ----------
        order : string, optional
            Traversal order: 'depthfirst' (the default) or 'breadthfirst'.

        extra : bool
            Whether to output extra information about the pixel
            (default is True).

        Yields
        ------
        nside : int
            The HEALPix resolution of the node.

        full_nside : int, present if extra=True
            The HEALPix resolution of the deepest node in the tree.

        ipix : int
            The nested HEALPix index of the node.

        ipix0 : int, present if extra=True
            The start index of the range of pixels spanned by the node at the
            resolution `full_nside`.

        ipix1 : int, present if extra=True
            The end index of the range of pixels spanned by the node at the
            resolution `full_nside`.

        samples : list, present if extra=True
            The list of samples contained in the node.

        Examples
        --------

        >>> ipix = np.arange(12) * HEALPIX_MACHINE_NSIDE**2
        >>> tree = HEALPixTree(ipix, max_samples_per_pixel=1, max_order=1)
        >>> [tuple(_) for _ in tree.visit(extra=False)]
        [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]
        """
        funcs = {'depthfirst': self._visit_depthfirst,
                 'breadthfirst': self._visit_breadthfirst}
        func = funcs[order]
        yield from func(extra)

    @property
    def flat_bitmap(self):
        """Return flattened HEALPix representation."""
        m = np.empty(ah.nside_to_npix(ah.level_to_nside(self.order)))
        for nside, full_nside, ipix, ipix0, ipix1, samples in self.visit():
            pixarea = ah.nside_to_pixel_area(nside).to_value(u.sr)
            m[ipix0:ipix1] = len(samples) / pixarea
        return m


def adaptive_healpix_histogram(
        theta, phi, max_samples_per_pixel, nside=-1, max_nside=-1, nest=False):
    """Adaptively histogram the posterior samples represented by the
    (theta, phi) points using a recursively subdivided HEALPix tree. Nodes are
    subdivided until each leaf contains no more than max_samples_per_pixel
    samples. Finally, the tree is flattened to a fixed-resolution HEALPix image
    with a resolution appropriate for the depth of the tree. If nside is
    specified, the result is resampled to another desired HEALPix resolution.
    """
    # Calculate pixel index of every sample, at the maximum 64-bit resolution.
    #
    # At this resolution, each pixel is only 0.2 mas across; we'll use the
    # 64-bit pixel indices as a proxy for the true sample coordinates so that
    # we don't have to do any trigonometry (aside from the initial hp.ang2pix
    # call).
    ipix = hp.ang2pix(HEALPIX_MACHINE_NSIDE, theta, phi, nest=True)

    # Determine the deepest order that the tree is allowed to reach: the
    # level of the smallest requested nside, or the machine limit if no
    # resolution was requested.
    requested = [n for n in (nside, max_nside) if n != -1]
    if requested:
        max_order = ah.nside_to_level(min(requested))
    else:
        max_order = HEALPIX_MACHINE_ORDER

    # Build tree structure and flatten it to a fixed-resolution map.
    tree = HEALPixTree(ipix, max_samples_per_pixel, max_order)
    p = tree.flat_bitmap

    # If requested, resample the map to the output resolution.
    if nside != -1:
        p = hp.ud_grade(p, nside, order_in='NESTED', order_out='NESTED')

    # Normalize to unit total probability.
    p /= np.sum(p)

    if not nest:
        p = hp.reorder(p, n2r=True)

    # Done!
    return p


def _interpolate_level(m):
    """Recursive multi-resolution interpolation. Modifies `m` in place."""
    npix = len(m)

    # Nothing to do at the base resolution (12 pixels, nside=1).
    if npix <= 12:
        return

    # A coarse tile appears as four consecutive children with equal values.
    tiles = np.flatnonzero(
        (m[0::4] == m[1::4]) &
        (m[0::4] == m[2::4]) &
        (m[0::4] == m[3::4]))

    if not len(tiles):
        return

    # Expand each tile index into its four child pixel indices.
    children = (4 * tiles[:, np.newaxis]
                + np.arange(4, dtype=np.intp)).ravel()

    nside = ah.npix_to_nside(npix)

    # Downsample, interpolate recursively at the coarser level, then copy
    # the interpolated values back into the multi-pixel tiles.
    m_lores = hp.ud_grade(
        m, nside // 2, order_in='NESTED', order_out='NESTED')
    _interpolate_level(m_lores)
    m[children] = hp.get_interp_val(
        m_lores, *hp.pix2ang(nside, children, nest=True), nest=True)


def interpolate_nested(m, nest=False):
    """
    Apply bilinear interpolation to a multiresolution HEALPix map, assuming
    that runs of pixels containing identical values are nodes of the tree. This
    smooths out the stair-step effect that may be noticeable in contour plots.

    Here is how it works. Consider a coarse tile surrounded by base tiles, like
    this::

                +---+---+
                |   |   |
                +-------+
                |   |   |
        +---+---+---+---+---+---+
        |   |   |       |   |   |
        +-------+       +-------+
        |   |   |       |   |   |
        +---+---+---+---+---+---+
                |   |   |
                +-------+
                |   |   |
                +---+---+

    The value within the central coarse tile is computed by downsampling the
    sky map (averaging the fine tiles), upsampling again (with bilinear
    interpolation), and then finally copying the interpolated values within the
    coarse tile back to the full-resolution sky map. This process is applied
    recursively at all successive HEALPix resolutions.

    Note that this method suffers from a minor discontinuity artifact at the
    edges of regions of coarse tiles, because it temporarily treats the
    bordering fine tiles as constant. However, this artifact seems to have only
    a minor effect on generating contour plots.

    Parameters
    ----------

    m: `~numpy.ndarray`
        a HEALPix array

    nest: bool, default: False
        Whether the input array is stored in the `NESTED` indexing scheme
        (True) or the `RING` indexing scheme (False).

    """
    # Work on a nested-ordered copy; never mutate the caller's array.
    work = m.copy() if nest else hp.reorder(m, r2n=True)

    _interpolate_level(work)

    # Return in the caller's original indexing scheme.
    return work if nest else hp.reorder(work, n2r=True)


def _reconstruct_nested_breadthfirst(m, extra):
    """Yield tree leaves coarsest-first by finding runs of identical values."""
    m = np.asarray(m)
    max_npix = len(m)
    max_nside = ah.npix_to_nside(max_npix)
    max_order = ah.nside_to_level(max_nside)
    # Track fine pixels already claimed by a coarser leaf.
    claimed = np.zeros(max_npix, dtype=bool)

    for order in range(max_order + 1):
        nside = ah.level_to_nside(order)
        skip = max_npix // ah.nside_to_npix(nside)
        if skip > 1:
            groups = m.reshape(-1, skip)
            first = groups[:, 0].reshape(-1, 1)
            rest = groups[:, 1:]
            # A group forms a leaf at this order when every value matches
            # the first (x != x is the NaN test, so NaNs match NaNs) and
            # none of its pixels is already claimed.
            same = ((first == rest) | ((first != first) & (rest != rest)))
            is_leaf = same.all(1) & (~claimed.reshape(-1, skip)).all(1)
        else:
            is_leaf = ~claimed
        for ipix in np.flatnonzero(is_leaf):
            ipix0 = ipix * skip
            ipix1 = (ipix + 1) * skip
            claimed[ipix0:ipix1] = True
            if extra:
                yield _HEALPixTreeVisitExtra(
                    nside, max_nside, ipix, ipix0, ipix1, m[ipix0])
            else:
                yield _HEALPixTreeVisit(nside, ipix)


def _reconstruct_nested_depthfirst(m, extra):
    """Yield tree leaves ordered by their position on the sphere (ipix0)."""
    leaves = sorted(
        _reconstruct_nested_breadthfirst(m, True),
        key=lambda leaf: leaf.ipix0)
    if extra:
        return leaves
    # Strip the extra fields when the caller asked for the compact form.
    return (_HEALPixTreeVisit(leaf.nside, leaf.ipix) for leaf in leaves)


def reconstruct_nested(m, order='depthfirst', extra=True):
    """Reconstruct the leaves of a multiresolution tree.

    Parameters
    ----------
    m : `~numpy.ndarray`
        A HEALPix array in the NESTED ordering scheme.

    order : {'depthfirst', 'breadthfirst'}, optional
        Traversal order: 'depthfirst' (the default) or 'breadthfirst'.

    extra : bool
        Whether to output extra information about the pixel (default is True).

    Yields
    ------
    nside : int
        The HEALPix resolution of the node.

    full_nside : int, present if extra=True
        The HEALPix resolution of the deepest node in the tree.

    ipix : int
        The nested HEALPix index of the node.

    ipix0 : int, present if extra=True
        The start index of the range of pixels spanned by the node at the
        resolution `full_nside`.

    ipix1 : int, present if extra=True
        The end index of the range of pixels spanned by the node at the
        resolution `full_nside`.

    value : list, present if extra=True
        The value of the map at the node.

    Examples
    --------

    An nside=1 array of all zeros:

    >>> m = np.zeros(12)
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]

    An nside=1 array of distinct values:

    >>> m = range(12)
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]

    An nside=8 array of zeros:

    >>> m = np.zeros(768)
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]

    An nside=2 array, all zeros except for four consecutive distinct elements:

    >>> m = np.zeros(48); m[:4] = range(4)
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (2, 0), (2, 1), (2, 2), (2, 3)]

    Same, but in depthfirst order:

    >>> result = reconstruct_nested(m, order='depthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(2, 0), (2, 1), (2, 2), (2, 3), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]

    An nside=2 array, all elements distinct except for four consecutive zeros:

    >>> m = np.arange(48); m[:4] = 0
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 0), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11), (2, 12), (2, 13), (2, 14), (2, 15), (2, 16), (2, 17), (2, 18), (2, 19), (2, 20), (2, 21), (2, 22), (2, 23), (2, 24), (2, 25), (2, 26), (2, 27), (2, 28), (2, 29), (2, 30), (2, 31), (2, 32), (2, 33), (2, 34), (2, 35), (2, 36), (2, 37), (2, 38), (2, 39), (2, 40), (2, 41), (2, 42), (2, 43), (2, 44), (2, 45), (2, 46), (2, 47)]
    """
    # Dispatch on traversal order; an unknown order raises KeyError,
    # matching the dict-lookup behavior.
    dispatch = {'depthfirst': _reconstruct_nested_depthfirst,
                'breadthfirst': _reconstruct_nested_breadthfirst}
    yield from dispatch[order](m, extra)

ligo/skymap/kde.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
#
# Copyright (C) 2012-2020  Will M. Farr <will.farr@ligo.org>
#                          Leo P. Singer <leo.singer@ligo.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import copyreg
from functools import partial

from astropy.coordinates import SkyCoord
from astropy.utils.misc import NumpyRNGContext
import healpy as hp
import logging
import numpy as np
from scipy.stats import gaussian_kde

from . import distance
from . import moc
from .coordinates import EigenFrame
from .util import progress_map

# Module-level logger. NOTE(review): getLogger() with no argument returns
# the root logger, not one named after this module — presumably intentional
# here; confirm before changing.
log = logging.getLogger()

__all__ = ('BoundedKDE', 'Clustered2DSkyKDE', 'Clustered3DSkyKDE',
           'Clustered2Plus1DSkyKDE')


class BoundedKDE(gaussian_kde):
    """Density estimation using a KDE on bounded domains.

    Bounds can be any combination of low or high (if no bound, set to
    ``float('inf')`` or ``float('-inf')``), and can be periodic or
    non-periodic.  Cannot handle topologies that have
    multi-dimensional periodicities; will only handle topologies that
    are direct products of (arbitrary numbers of) R, [0,1], and S1.

    Parameters
    ----------
    pts : :class:`numpy.ndarray`
        ``(Ndim, Npts)`` shaped array of points (as in :class:`gaussian_kde`).
    low
        Lower bounds; if ``None``, assume no lower bounds.
    high
        Upper bounds; if ``None``, assume no upper bounds.
    periodic
        Boolean array giving periodicity in each dimension; if
        ``None`` assume no dimension is periodic.
    bw_method : optional
        Bandwidth estimation method (see :class:`gaussian_kde`).

    """

    def __init__(self, pts, low=-np.inf, high=np.inf, periodic=False,
                 bw_method=None):
        super().__init__(pts, bw_method=bw_method)
        # Broadcast the bound specifications to one entry per dimension.
        dtype = self.dataset.dtype
        self._low = np.broadcast_to(low, self.d).astype(dtype)
        self._high = np.broadcast_to(high, self.d).astype(dtype)
        self._periodic = np.broadcast_to(periodic, self.d).astype(bool)

    def evaluate(self, pts):
        """Evaluate the KDE at the given points."""
        pts = np.atleast_2d(pts)
        d, m = pts.shape
        # Accept a batch of 1D points passed as a row vector.
        if d != self.d and d == 1 and m == self.d:
            pts = pts.T

        original = pts
        work = np.copy(original)

        # Start with the unbounded density, then add image contributions
        # one dimension at a time, restoring that row after each term.
        den = super().evaluate(work)

        for dim in range(self.d):
            low = self._low[dim]
            high = self._high[dim]
            if self._periodic[dim]:
                period = high - low

                # Neighboring periodic images: +period and -period.
                work[dim, :] += period
                den += super().evaluate(work)

                work[dim, :] -= 2.0 * period
                den += super().evaluate(work)

                work[dim, :] = original[dim, :]
            else:
                # Reflect across each finite boundary.
                if not np.isneginf(low):
                    work[dim, :] = 2.0 * low - work[dim, :]
                    den += super().evaluate(work)
                    work[dim, :] = original[dim, :]

                if not np.isposinf(high):
                    work[dim, :] = 2.0 * high - work[dim, :]
                    den += super().evaluate(work)
                    work[dim, :] = original[dim, :]

        return den

    __call__ = evaluate

    def quantile(self, pt):
        """Quantile of ``pt``, evaluated by a greedy algorithm.

        Parameters
        ----------
        pt
            The point at which the quantile value is to be computed.

        Notes
        -----
        The quantile of ``pt`` is the fraction of points used to construct the
        KDE that have a lower KDE density than ``pt``.

        """
        return np.count_nonzero(self(self.dataset) < self(pt)) / self.n


def km_assign(mus, cov, pts):
    """Implement the assignment step in the k-means algorithm.

    Given a set of centers, ``mus``, a covariance matrix used to produce a
    metric on the space, ``cov``, and a set of points, ``pts`` (shape ``(npts,
    ndim)``), assigns each point to its nearest center, returning an array of
    indices of shape ``(npts,)`` giving the assignments.
    """
    nclusters = mus.shape[0]
    npts = pts.shape[0]

    dists = np.zeros((nclusters, npts))

    for row, center in enumerate(mus):
        offset = pts - center
        try:
            # Squared Mahalanobis distance under the metric ``cov``.
            dists[row, :] = np.sum(
                offset * np.linalg.solve(cov, offset.T).T, axis=1)
        except np.linalg.LinAlgError:
            # Singular metric: mark this center's distances as unusable.
            dists[row, :] = np.nan

    # nanargmin skips centers whose distances could not be computed.
    return np.nanargmin(dists, axis=0)


def km_centroids(pts, assign, k):
    """Implement the centroid-update step of the k-means algorithm.

    Given a set of points, ``pts``, of shape ``(npts, ndim)``, and an
    assignment of each point to a region, ``assign``, and the number of means,
    ``k``, returns an array of shape ``(k, ndim)`` giving the centroid of each
    region.
    """
    centroids = np.zeros((k, pts.shape[1]))
    for label in range(k):
        members = pts[assign == label, :]
        if len(members):
            centroids[label, :] = np.mean(members, axis=0)
        else:
            # Re-seed an empty cluster with a randomly chosen point.
            centroids[label, :] = pts[np.random.randint(pts.shape[0]), :]

    return centroids


def k_means(pts, k):
    """Perform k-means clustering on the set of points.

    Parameters
    ----------
    pts
        Array of shape ``(npts, ndim)`` giving the points on which k-means is
        to operate.
    k
        Positive integer giving the number of regions.

    Returns
    -------
    centroids
        An ``(k, ndim)`` array giving the centroid of each region.
    assign
        An ``(npts,)`` array of integers between 0 (inclusive) and k
        (exclusive) indicating the assignment of each point to a region.

    """
    assert pts.shape[0] > k, 'must have more points than means'

    # The sample covariance provides the Mahalanobis metric for assignment.
    metric = np.cov(pts, rowvar=0)

    # Seed the centers with k randomly chosen points, then alternate the
    # assignment and centroid-update steps until the labels stop changing.
    centers = np.random.permutation(pts)[:k, :]
    labels = km_assign(centers, metric, pts)
    while True:
        previous = labels
        centers = km_centroids(pts, labels, k)
        labels = km_assign(centers, metric, pts)
        if np.all(labels == previous):
            return centers, labels


def _cluster(cls, pts, trials, i, seed, jobs):
    k = i // trials
    if k == 0:
        raise ValueError('Expected at least one cluster')
    try:
        if k == 1:
            assign = np.zeros(len(pts), dtype=np.intp)
        else:
            with NumpyRNGContext(i + seed):
                _, assign = k_means(pts, k)
        obj = cls(pts, assign=assign)
    except np.linalg.LinAlgError:
        return -np.inf,
    else:
        return obj.bic, k, obj.kdes


class ClusteredKDE:
    """Mixture of Gaussian KDEs, one per k-means cluster of the input points.

    Either searches over the number of clusters (maximizing the BIC) when
    ``assign`` is None, or fits KDEs directly to a given cluster assignment.
    """

    def __init__(self, pts, max_k=40, trials=5, assign=None, jobs=1):
        # pts: (npts, ndim) array of samples.
        # max_k: largest number of clusters tried in the BIC search.
        # trials: number of random restarts per candidate k.
        # assign: optional precomputed cluster labels; skips the search.
        # jobs: parallelism passed to progress_map via self._map.
        self.jobs = jobs
        if assign is None:
            log.info('clustering ...')
            # Make sure that each thread gets a different random number state.
            # We start by drawing a random integer s in the main thread, and
            # then the i'th subprocess will seed itself with the integer i + s.
            #
            # The seed must be an unsigned 32-bit integer, so if there are n
            # threads, then s must be drawn from the interval [0, 2**32 - n).
            seed = np.random.randint(0, 2**32 - max_k * trials)
            func = partial(_cluster, type(self), pts, trials, seed=seed,
                           jobs=jobs)
            # Each trial index i encodes k = i // trials; keep the fit with
            # the best (bic, k) pair.
            self.bic, self.k, self.kdes = max(
                self._map(func, range(trials, (max_k + 1) * trials)),
                key=lambda items: items[:2])
        else:
            # Build KDEs for each cluster, skipping degenerate clusters
            self.kdes = []
            npts, ndim = pts.shape
            self.k = assign.max() + 1
            for i in range(self.k):
                sel = (assign == i)
                cluster_pts = pts[sel, :]
                # Equivalent to but faster than len(set(pts))
                nuniq = len(np.unique(cluster_pts, axis=0))
                # Skip if there are fewer unique points than dimensions
                if nuniq <= ndim:
                    continue
                try:
                    kde = gaussian_kde(cluster_pts.T)
                except (np.linalg.LinAlgError, ValueError):
                    # If there are fewer unique points than degrees of freedom,
                    # then the KDE will fail because the covariance matrix is
                    # singular. In that case, don't bother adding that cluster.
                    pass
                else:
                    self.kdes.append(kde)

            # Calculate BIC
            # The number of parameters is:
            #
            # * ndim for each centroid location
            #
            # * (ndim+1)*ndim/2 Kernel covariances for each cluster
            #
            # * one weighting factor for the cluster (minus one for the
            #   overall constraint that the weights must sum to one)
            nparams = (self.k * ndim +
                       0.5 * self.k * (ndim + 1) * ndim + self.k - 1)
            with np.errstate(divide='ignore'):
                self.bic = (
                    np.sum(np.log(self.eval_kdes(pts))) -
                    0.5 * nparams * np.log(npts))

    def eval_kdes(self, pts):
        """Evaluate the weighted mixture density at ``pts`` (npts, ndim)."""
        pts = pts.T
        return sum(w * kde(pts) for w, kde in zip(self.weights, self.kdes))

    def __call__(self, pts):
        """Alias for :meth:`eval_kdes`."""
        return self.eval_kdes(pts)

    @property
    def weights(self):
        """Get the cluster weights: the fraction of the points within each
        cluster.
        """
        w = np.asarray([kde.n for kde in self.kdes])
        return w / np.sum(w)

    def _map(self, func, items):
        # Fan the trials out over `jobs` workers with a progress bar.
        return progress_map(func, items, jobs=self.jobs)


class SkyKDE(ClusteredKDE):
    """Base class for clustered KDEs of sky localization posteriors.

    Subclasses provide a :meth:`transform` that maps (ra, dec) pairs in
    radians to the internal coordinate system in which the KDE is fit.
    """

    @classmethod
    def transform(cls, pts):
        """Override in sub-classes to transform points."""
        raise NotImplementedError

    def __init__(self, pts, max_k=40, trials=5, assign=None, jobs=1):
        # When cluster assignments are supplied, the transform is skipped;
        # presumably the points are already in internal coordinates in that
        # case -- NOTE(review): confirm against callers.
        if assign is None:
            pts = self.transform(pts)
        super().__init__(
            pts, max_k=max_k, trials=trials, assign=assign, jobs=jobs)

    def __call__(self, pts):
        # Evaluate the KDE at (ra, dec) points by transforming them first.
        return super().__call__(self.transform(pts))

    def as_healpix(self, top_nside=16, rounds=8):
        # Adaptively sample this density into a multi-order HEALPix map.
        return moc.bayestar_adaptive_grid(self, top_nside=top_nside,
                                          rounds=rounds)


# We have to put in some hooks to make instances of Clustered2DSkyKDE picklable
# because we dynamically create subclasses with different values of the 'frame'
# class variable. This gets even trickier because we need both the class and
# instance objects to be picklable.


class _Clustered2DSkyKDEMeta(type):  # noqa: N802
    """Metaclass to make dynamically created subclasses of Clustered2DSkyKDE
    picklable.

    The class body is intentionally empty: the metaclass exists only so that
    a pickle reduction function can be registered for it with ``copyreg``.
    """


def _Clustered2DSkyKDEMeta_pickle(cls):  # noqa: N802
    """Pickle dynamically created subclasses of Clustered2DSkyKDE."""
    # Reconstruct the class on unpickling by calling
    # type(name, bases, namespace); the only state that must survive is
    # the 'frame' class attribute.
    reconstructor_args = (cls.__name__, cls.__bases__, {'frame': cls.frame})
    return type, reconstructor_args


# Register the reduction function so that pickling any class whose metaclass
# is _Clustered2DSkyKDEMeta goes through _Clustered2DSkyKDEMeta_pickle.
copyreg.pickle(_Clustered2DSkyKDEMeta, _Clustered2DSkyKDEMeta_pickle)


def _Clustered2DSkyKDE_factory(name, frame):  # noqa: N802
    """Unpickle instances of dynamically created subclasses of
    Clustered2DSkyKDE.

    Recreates the per-instance subclass with the given name and 'frame'
    class attribute, then allocates a bare instance of it; the instance
    __dict__ is restored separately by the pickle machinery (see
    Clustered2DSkyKDE.__reduce__).
    """
    new_cls = type(name, (Clustered2DSkyKDE,), {'frame': frame})
    # Bypass Clustered2DSkyKDE.__new__, which would create yet another
    # dynamic subclass; allocate the instance directly.
    return super(Clustered2DSkyKDE, Clustered2DSkyKDE).__new__(new_cls)


class Clustered2DSkyKDE(SkyKDE, metaclass=_Clustered2DSkyKDEMeta):
    r"""Represents a kernel-density estimate of a sky-position PDF that has
    been decomposed into clusters, using a different kernel for each
    cluster.

    The estimated PDF is

    .. math::

      p\left( \vec{\theta} \right) = \sum_{i = 0}^{k-1} \frac{N_i}{N}
      \sum_{\vec{x} \in C_i} N\left[\vec{x}, \Sigma_i\right]\left( \vec{\theta}
      \right)

    where :math:`C_i` is the set of points belonging to cluster
    :math:`i`, :math:`N_i` is the number of points in this cluster,
    :math:`\Sigma_i` is the optimally-converging KDE covariance
    associated to cluster :math:`i`.

    The number of clusters, :math:`k` is chosen to maximize the `BIC
    <http://en.wikipedia.org/wiki/Bayesian_information_criterion>`_
    for the given set of points being drawn from the clustered KDE.
    The points are assigned to clusters using the k-means algorithm,
    with a decorrelated metric.  The overall clustering behavior is
    similar to the well-known `X-Means
    <http://www.cs.cmu.edu/~dpelleg/download/xmeans.pdf>`_ algorithm.
    """

    # Coordinate frame in which the KDE is fit; set per instance by __new__.
    frame = None

    @classmethod
    def transform(cls, pts):
        # Map (ra, dec) in radians into the class's frame and return
        # (longitude, sin(latitude)) columns.
        pts = SkyCoord(*pts.T, unit='rad').transform_to(cls.frame).spherical
        return np.column_stack((pts.lon.rad, np.sin(pts.lat.rad)))

    def __new__(cls, pts, *args, **kwargs):
        # Create a dynamic per-instance subclass whose 'frame' class
        # attribute is the eigenframe of the data, so that the classmethod
        # transform() can see it.
        frame = EigenFrame.for_coords(SkyCoord(*pts.T, unit='rad'))
        name = '{:s}_{:x}'.format(cls.__name__, id(frame))
        new_cls = type(name, (cls,), {'frame': frame})
        return super().__new__(new_cls)

    def __reduce__(self):
        """Pickle instances of dynamically created subclasses of
        Clustered2DSkyKDE.
        """
        # Recreate the dynamic subclass via the factory, then restore the
        # instance __dict__.
        factory_args = self.__class__.__name__, self.frame
        return _Clustered2DSkyKDE_factory, factory_args, self.__dict__

    def eval_kdes(self, pts):
        # Sum evaluations at phi, phi + 2*pi, and phi - 2*pi to account for
        # the periodic boundary in the longitude coordinate.
        base = super().eval_kdes
        dphis = (0.0, 2 * np.pi, -2 * np.pi)
        phi, z = pts.T
        return sum(base(np.column_stack((phi + dphi, z))) for dphi in dphis)


class Clustered3DSkyKDE(SkyKDE):
    """Like :class:`Clustered2DSkyKDE`, but clusters in 3D
    space.  Can compute volumetric posterior density (per cubic Mpc),
    and also produce Healpix maps of the mean and standard deviation
    of the log-distance.
    """

    @classmethod
    def transform(cls, pts):
        # Convert (ra, dec, distance) rows to Cartesian (x, y, z) rows.
        return SkyCoord(*pts.T, unit='rad').cartesian.xyz.value.T

    def __call__(self, pts, distances=False):
        """Given an array of positions in RA, DEC, compute the marginal sky
        posterior and optionally the conditional distance parameters.
        """
        # Integrate the 3D KDE along each line of sight to get the sky
        # probability density and the distance moments.
        func = partial(distance.cartesian_kde_to_moments,
                       datasets=[_.dataset for _ in self.kdes],
                       inverse_covariances=[_.inv_cov for _ in self.kdes],
                       weights=self.weights)
        probdensity, mean, std = zip(*self._map(func, self.transform(pts)))
        if distances:
            mu, sigma, norm = distance.moments_to_parameters(mean, std)
            return probdensity, mu, sigma, norm
        else:
            return probdensity

    def posterior_spherical(self, pts):
        """Evaluate the posterior probability density in spherical polar
        coordinates, as a function of (ra, dec, distance).
        """
        return super().__call__(pts)

    def as_healpix(self, top_nside=16):
        """Return a HEALPix multi-order map of the posterior density and
        conditional distance distribution parameters.
        """
        m = super().as_healpix(top_nside=top_nside)
        # Recover the per-pixel sky positions from the NUNIQ indices.
        order, ipix = moc.uniq2nest(m['UNIQ'])
        nside = 2 ** order.astype(int)
        theta, phi = hp.pix2ang(nside, ipix, nest=True)
        p = np.column_stack((phi, 0.5 * np.pi - theta))
        print('evaluating distance layers ...')
        _, m['DISTMU'], m['DISTSIGMA'], m['DISTNORM'] = self(p, distances=True)
        return m


class Clustered2Plus1DSkyKDE(Clustered3DSkyKDE):
    """A hybrid sky map estimator that uses a 2D clustered KDE for the marginal
    distribution as a function of (RA, Dec) and a 3D clustered KDE for the
    conditional distance distribution.
    """

    def __init__(self, pts, max_k=40, trials=5, assign=None, jobs=1):
        # The 2D marginal KDE is only fit on a fresh fit; when cluster
        # assignments are supplied, self.twod is presumably restored some
        # other way (e.g. by unpickling) -- NOTE(review): confirm.
        if assign is None:
            self.twod = Clustered2DSkyKDE(
                pts, max_k=max_k, trials=trials, assign=assign, jobs=jobs)
        super().__init__(
            pts, max_k=max_k, trials=trials, assign=assign, jobs=jobs)

    def __call__(self, pts, distances=False):
        # Sky density comes from the 2D KDE; distance parameters, when
        # requested, come from the 3D KDE.
        probdensity = self.twod(pts)
        if distances:
            _, distmu, distsigma, distnorm = super().__call__(
                pts, distances=True)
            return probdensity, distmu, distsigma, distnorm
        else:
            return probdensity

    def posterior_spherical(self, pts):
        """Evaluate the posterior probability density in spherical polar
        coordinates, as a function of (ra, dec, distance).
        """
        # p(ra, dec, dist) = p2d(ra, dec)
        #                    * p3d(ra, dec, dist) / p3d_marginal(ra, dec)
        return self(pts) * super().posterior_spherical(pts) / super().__call__(
            pts)

ligo/skymap/moc.py
#
# Copyright (C) 2017-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Support for HEALPix UNIQ pixel indexing [1]_ and multi-order coverage (MOC)
maps [2]_.

References
----------
.. [1] Reinecke & Hivon, 2015. "Efficient data structures for masks on 2D
       grids." AA 580, A132. :doi:`10.1051/0004-6361/201526549`
.. [2] Boch et al., 2014. "MOC - HEALPix Multi-Order Coverage map." IVOA
       Recommendation <http://ivoa.net/documents/MOC/>.

"""

from astropy import table
from astropy import units as u
import astropy_healpix as ah
import numpy as np
from numpy.lib.recfunctions import repack_fields

from .core import nest2uniq, uniq2nest, uniq2order, uniq2pixarea, uniq2ang
from .core import rasterize as _rasterize
from .util.numpy import add_newdoc_ufunc

__all__ = ('nest2uniq', 'uniq2nest', 'uniq2order', 'uniq2pixarea',
           'uniq2ang', 'rasterize', 'bayestar_adaptive_grid')


# Attach docstrings to the compiled ufuncs imported from the C extension.
add_newdoc_ufunc(nest2uniq, """\
Convert a pixel index from NESTED to NUNIQ ordering.

Parameters
----------
order : `numpy.ndarray`
    HEALPix resolution order, the logarithm base 2 of `nside`
ipix : `numpy.ndarray`
    NESTED pixel index

Returns
-------
uniq : `numpy.ndarray`
    NUNIQ pixel index

""")


# FIX: the summary previously said "NESTED index", but the argument is a
# NUNIQ index (as the Parameters section already stated).
add_newdoc_ufunc(uniq2order, """\
Determine the HEALPix resolution order of a HEALPix NUNIQ index.

Parameters
----------
uniq : `numpy.ndarray`
    NUNIQ pixel index

Returns
-------
order : `numpy.ndarray`
    HEALPix resolution order, the logarithm base 2 of `nside`

""")


# FIX: the summary previously said "NESTED index", but the argument is a
# NUNIQ index (as the Parameters section already stated).
add_newdoc_ufunc(uniq2pixarea, """\
Determine the area of a HEALPix NUNIQ index.

Parameters
----------
uniq : `numpy.ndarray`
    NUNIQ pixel index

Returns
-------
area : `numpy.ndarray`
    The pixel's area in steradians

""")


add_newdoc_ufunc(uniq2nest, """\
Convert a pixel index from NUNIQ to NESTED ordering.

Parameters
----------
uniq : `numpy.ndarray`
    NUNIQ pixel index

Returns
-------
order : `numpy.ndarray`
    HEALPix resolution order (logarithm base 2 of `nside`)
ipix : `numpy.ndarray`
    NESTED pixel index

""")


def rasterize(moc_data, order=None):
    """Convert a multi-order HEALPix dataset to fixed-order NESTED ordering.

    Parameters
    ----------
    moc_data : `numpy.ndarray`
        A multi-order HEALPix dataset stored as a Numpy record array whose
        first column is called UNIQ and contains the NUNIQ pixel index. Every
        point on the unit sphere must be contained in exactly one pixel in the
        dataset.
    order : int, optional
        The desired output resolution order, or :obj:`None` for the maximum
        resolution present in the dataset.

    Returns
    -------
    nested_data : `numpy.ndarray`
        A fixed-order, NESTED-ordering HEALPix dataset with all of the columns
        that were in moc_data, with the exception of the UNIQ column.

    """
    if np.ndim(moc_data) != 1:
        raise ValueError('expected 1D structured array or Astropy table')
    elif order is None or order < 0:
        # A negative order selects the maximum resolution present in the
        # dataset (see the `order` parameter documentation above).
        order = -1
    else:
        orig_order, orig_nest = uniq2nest(moc_data['UNIQ'])
        # Pixels finer than the requested order must be averaged together
        # into their enclosing coarse pixels before rasterization.
        to_downsample = order < orig_order
        if np.any(to_downsample):
            to_keep = table.Table(moc_data[~to_downsample], copy=False)
            orig_order = orig_order[to_downsample]
            orig_nest = orig_nest[to_downsample]
            to_downsample = table.Table(moc_data[to_downsample], copy=False)

            # Each coarse pixel contains `ratio` fine pixels; weight the fine
            # values so that summing over a group yields their average.
            ratio = 1 << (2 * np.int64(orig_order - order))
            weights = 1.0 / ratio
            for colname, column in to_downsample.columns.items():
                if colname != 'UNIQ':
                    column *= weights
            # Relabel fine pixels with the NUNIQ index of the enclosing
            # coarse pixel, then sum each group of fine pixels.
            to_downsample['UNIQ'] = nest2uniq(order, orig_nest // ratio)
            to_downsample = to_downsample.group_by(
                'UNIQ').groups.aggregate(np.sum)

            moc_data = table.vstack((to_keep, to_downsample))

    # Ensure that moc_data has appropriate padding for each of its columns to
    # be properly aligned in order to avoid undefined behavior.
    moc_data = repack_fields(np.asarray(moc_data), align=True)

    return _rasterize(moc_data, order=order)


def bayestar_adaptive_grid(probdensity, *args, top_nside=16, rounds=8,
                           **kwargs):
    """Create a sky map by evaluating a function on an adaptive grid.

    Perform the BAYESTAR adaptive mesh refinement scheme as described in
    Section VI of Singer & Price 2016, PRD, 93, 024013
    :doi:`10.1103/PhysRevD.93.024013`. This computes the sky map
    using a provided analytic function and refines the grid, dividing the
    highest 25% into subpixels and then recalculating their values. The extra
    given args and kwargs will be passed to the given probdensity function.

    Parameters
    ----------
    probdensity : callable
        Probability density function. The first argument consists of
        column-stacked array of right ascension and declination in radians.
        The return value must be a 1D array of the probability density in
        inverse steradians with the same length as the argument.
    top_nside : int
        HEALPix NSIDE resolution of initial evaluation of the sky map
    rounds : int
        Number of refinement rounds, including the initial sky map evaluation

    Returns
    -------
    skymap : astropy.table.Table
        An astropy Table with UNIQ and PROBDENSITY columns, representing
        a multi-ordered sky map
    """
    top_npix = ah.nside_to_npix(top_nside)
    nrefine = top_npix // 4
    # Seed every top-level cell with zero probability at half the top
    # resolution; the first loop iteration then refines all of them to
    # top_nside and evaluates the density there.
    cells = zip([0] * nrefine, [top_nside // 2] * nrefine, range(nrefine))
    for iround in range(rounds + 1):
        print('adaptive refinement round {} of {} ...'.format(
            iround, rounds))
        # Sort cells by the probability they contain: density times pixel
        # area, which is proportional to p / nside**2.
        cells = sorted(cells, key=lambda p_n_i: p_n_i[0] / p_n_i[1]**2)
        # Subdivide the top quarter of cells into their four NESTED children.
        new_nside, new_ipix = np.transpose([
            (nside * 2, ipix * 4 + i)
            for _, nside, ipix in cells[-nrefine:] for i in range(4)])
        ra, dec = ah.healpix_to_lonlat(new_ipix, new_nside, order='nested')
        p = probdensity(np.column_stack((ra.value, dec.value)),
                        *args, **kwargs)
        cells[-nrefine:] = zip(p, new_nside, new_ipix)

    # Assemble the refined cells into a HEALPix multi-order map.
    # (FIX: this was previously a stray bare string literal left over as a
    # dead statement; it is now a comment.)
    post, nside, ipix = zip(*cells)
    post = np.asarray(list(post))
    nside = np.asarray(list(nside))
    ipix = np.asarray(list(ipix))

    # Make sure that sky map is normalized (it should be already)
    post /= np.sum(post * ah.nside_to_pixel_area(nside).to_value(u.sr))

    # Convert from NESTED to UNIQ pixel indices
    order = np.log2(nside).astype(int)
    uniq = nest2uniq(order.astype(np.int8), ipix)

    # Done!
    return table.Table([uniq, post], names=['UNIQ', 'PROBDENSITY'],
                       copy=False)


del add_newdoc_ufunc

ligo/skymap/bayestar/__init__.py
#
# Copyright (C) 2013-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Rapid sky localization with BAYESTAR [1]_.

References
----------
.. [1] Singer & Price, 2016. "Rapid Bayesian position reconstruction for
       gravitational-wave transients." PRD, 93, 024013.
       :doi:`10.1103/PhysRevD.93.024013`

"""
import inspect
import logging
import os
import sys
from textwrap import wrap

from astropy.table import Column, Table
from astropy import units as u
import lal
import lalsimulation
import numpy as np

from .. import distance
from . import filter
from ..io.hdf5 import write_samples
from ..io.fits import metadata_for_version_module
from ..io.events.base import Event
from . import filter  # noqa
from ..kde import Clustered2Plus1DSkyKDE
from .. import moc
from .. import healpix_tree
from .. import version
from .. import core
from ..core import (antenna_factor, signal_amplitude_model,
                    log_posterior_toa_phoa_snr as _log_posterior_toa_phoa_snr)
from ..util.numpy import require_contiguous_aligned
from ..util.stopwatch import Stopwatch
from .ez_emcee import ez_emcee

__all__ = ('derasterize', 'localize', 'rasterize', 'antenna_factor',
           'signal_amplitude_model')

# Module-level logger for all BAYESTAR progress and diagnostic messages.
log = logging.getLogger('BAYESTAR')

# Default factor by which the log likelihood is rescaled in the posterior.
_RESCALE_LOGLIKELIHOOD = 0.83

# Wrap the compiled ufuncs so that their array arguments are coerced to
# contiguous, aligned buffers before each call.
signal_amplitude_model = require_contiguous_aligned(signal_amplitude_model)
_log_posterior_toa_phoa_snr = require_contiguous_aligned(
    _log_posterior_toa_phoa_snr)


# Wrap so that ufunc parameter names are known
def log_posterior_toa_phoa_snr(
        ra, sin_dec, distance, u, twopsi, t, min_distance, max_distance,
        prior_distance_power, cosmology, gmst, sample_rate, epochs, snrs,
        responses, locations, horizons,
        rescale_loglikelihood=_RESCALE_LOGLIKELIHOOD):
    """Evaluate the BAYESTAR log posterior.

    Thin wrapper around the compiled ufunc so that the parameter names are
    visible to callers and ``rescale_loglikelihood`` gets a default value.
    """
    ufunc_args = (
        ra, sin_dec, distance, u, twopsi, t, min_distance, max_distance,
        prior_distance_power, cosmology, gmst, sample_rate, epochs, snrs,
        responses, locations, horizons, rescale_loglikelihood)
    return _log_posterior_toa_phoa_snr(*ufunc_args)


# Wrap so that fixed parameter values are pulled from keyword arguments
def log_post(params, *args, **kwargs):
    # Names of parameters
    keys = ('ra', 'sin_dec', 'distance', 'u', 'twopsi', 't')

    params = list(params.T)
    kwargs = dict(kwargs)

    return log_posterior_toa_phoa_snr(
        *(kwargs.pop(key) if key in kwargs else params.pop(0) for key in keys),
        *args, **kwargs)


def localize_emcee(args, xmin, xmax, chain_dump=None):
    """Localize an event by MCMC sampling followed by a KDE fit.

    Runs the emcee sampler over the posterior, optionally dumps the chain
    to an HDF5 file, and returns a multi-order HEALPix sky map fit to the
    samples.
    """
    # Draw posterior samples.
    samples = ez_emcee(log_post, xmin, xmax, args=args, vectorize=True)

    # Undo the sampling parameterization: sin(dec) -> dec and
    # cos(inclination) -> inclination.
    samples[:, 1] = np.arcsin(samples[:, 1])
    samples[:, 3] = np.arccos(samples[:, 3])

    # Optionally save posterior sample chain to file.
    if chain_dump:
        ndim = samples.shape[1]
        all_names = 'ra dec distance inclination twopsi time'.split()
        write_samples(
            Table(rows=samples, names=all_names[:ndim], copy=False),
            chain_dump, path='/bayestar/posterior_samples', overwrite=True)

    # Fit the KDE to a random subset of 1000 points, to save time.
    subset = np.random.permutation(samples)[:1000, :3]
    return Clustered2Plus1DSkyKDE(subset).as_healpix()


def condition(
        event, waveform='o2-uberbank', f_low=30.0,
        enable_snr_series=True, f_high_truncate=0.95):
    """Prepare an event's detector data for the BAYESTAR likelihood.

    Extracts per-detector SNRs, PSDs, responses, locations, and horizon
    distances; synthesizes zero-noise SNR time series from the template
    autocorrelation when none are provided; validates and trims the SNR
    series; and centers the detector array and arrival times.

    Parameters
    ----------
    event : `ligo.skymap.io.events.Event`
        The event candidate, with one single-detector trigger per detector.
    waveform : str, optional
        The name of the waveform approximant.
    f_low : float, optional
        The low frequency cutoff in Hz.
    enable_snr_series : bool, optional
        Set to False to ignore SNR time series and detectors without a
        scalar SNR.
    f_high_truncate : float, optional
        Truncate the PSDs at this factor times the highest sampled frequency.

    Returns
    -------
    tuple
        ``(epoch, sample_rate, toas, snrs, responses, locations, horizons)``.

    Raises
    ------
    ValueError
        If the event has no detectors, or if the SNR time series have mixed
        sample rates, mixed lengths, even lengths, or are not centered on the
        trigger times.
    """

    if len(event.singles) == 0:
        raise ValueError('Cannot localize an event with zero detectors.')

    singles = event.singles
    if not enable_snr_series:
        # Without SNR series, only detectors with a scalar SNR are usable.
        singles = [single for single in singles if single.snr is not None]

    ifos = [single.detector for single in singles]

    # Extract SNRs from table.
    snrs = np.ma.asarray([
        np.ma.masked if single.snr is None else single.snr
        for single in singles])

    # Look up physical parameters for detector.
    detectors = [lalsimulation.DetectorPrefixToLALDetector(str(ifo))
                 for ifo in ifos]
    responses = np.asarray([det.response for det in detectors])
    # Locations are expressed in light-seconds (divide by c).
    locations = np.asarray([det.location for det in detectors]) / lal.C_SI

    # Power spectra for each detector.
    psds = [single.psd for single in singles]
    psds = [filter.InterpolatedPSD(filter.abscissa(psd), psd.data.data,
                                   f_high_truncate=f_high_truncate)
            for psd in psds]

    log.debug('calculating templates')
    H = filter.sngl_inspiral_psd(waveform, f_min=f_low, **event.template_args)

    log.debug('calculating noise PSDs')
    HS = [filter.signal_psd_series(H, S) for S in psds]

    # Signal models for each detector.
    log.debug('calculating Fisher matrix elements')
    signal_models = [filter.SignalModel(_) for _ in HS]

    # Get SNR=1 horizon distances for each detector.
    horizons = np.asarray([signal_model.get_horizon_distance()
                           for signal_model in signal_models])

    # Weight each detector by the inverse variance of its time-of-arrival
    # estimate at its observed SNR.
    weights = np.ma.asarray([
        1 / np.square(signal_model.get_crb_toa_uncert(snr))
        for signal_model, snr in zip(signal_models, snrs)])

    # Center detector array.
    locations -= (np.sum(locations * weights.reshape(-1, 1), axis=0) /
                  np.sum(weights))

    if enable_snr_series:
        snr_series = [single.snr_series for single in singles]
        if all(s is None for s in snr_series):
            snr_series = None
    else:
        snr_series = None

    # Maximum barycentered arrival time error:
    # |distance from array barycenter to furthest detector| / c + 5 ms.
    # For LHO+LLO, this is 15.0 ms.
    # For an arbitrary terrestrial detector network, the maximum is 26.3 ms.
    max_abs_t = np.max(
        np.sqrt(np.sum(np.square(locations), axis=1))) + 0.005

    # Calculate a minimum arrival time prior for low-bandwidth signals.
    # This is important for early-warning events, for which the light travel
    # time between detectors may be only a tiny fraction of the template
    # autocorrelation time.
    #
    # The period of the autocorrelation function is 2 pi times the timing
    # uncertainty at SNR=1 (see Eq. (24) of [1]_). We want a quarter period:
    # there is a factor of a half because we want the first zero crossing,
    # and another factor of a half because the SNR time series will be cropped
    # to a two-sided interval.
    max_acor_t = max(0.5 * np.pi * signal_model.get_crb_toa_uncert(1)
                     for signal_model in signal_models)
    max_abs_t = max(max_abs_t, max_acor_t)

    if snr_series is None:
        log.warning("No SNR time series found, so we are creating a "
                    "zero-noise SNR time series from the whitened template's "
                    "autocorrelation sequence. The sky localization "
                    "uncertainty may be underestimated.")

        acors, sample_rates = zip(
            *[filter.autocorrelation(_, max_abs_t) for _ in HS])
        sample_rate = sample_rates[0]
        deltaT = 1 / sample_rate
        nsamples = len(acors[0])
        assert all(sample_rate == _ for _ in sample_rates)
        assert all(nsamples == len(_) for _ in acors)
        # Mirror the one-sided autocorrelation into a two-sided series.
        nsamples = nsamples * 2 - 1

        snr_series = []
        for acor, single in zip(acors, singles):
            series = lal.CreateCOMPLEX8TimeSeries(
                'fake SNR', 0, 0, deltaT, lal.StrainUnit, nsamples)
            series.epoch = single.time - 0.5 * (nsamples - 1) * deltaT
            acor = np.concatenate((np.conj(acor[:0:-1]), acor))
            series.data.data = single.snr * filter.exp_i(single.phase) * acor
            snr_series.append(series)

    # Ensure that all of the SNR time series have the same sample rate.
    # FIXME: for now, the Python wrapper expects all of the SNR time series to
    # also be the same length.
    deltaT = snr_series[0].deltaT
    sample_rate = 1 / deltaT
    if any(deltaT != series.deltaT for series in snr_series):
        raise ValueError('BAYESTAR does not yet support SNR time series with '
                         'mixed sample rates')

    # Ensure that all of the SNR time series have odd lengths.
    if any(len(series.data.data) % 2 == 0 for series in snr_series):
        raise ValueError('SNR time series must have odd lengths')

    # Trim time series to the desired length.
    max_abs_n = int(np.ceil(max_abs_t * sample_rate))
    desired_length = 2 * max_abs_n - 1
    for i, series in enumerate(snr_series):
        length = len(series.data.data)
        if length > desired_length:
            snr_series[i] = lal.CutCOMPLEX8TimeSeries(
                series, length // 2 + 1 - max_abs_n, desired_length)

    # FIXME: for now, the Python wrapper expects all of the SNR time sries to
    # also be the same length.
    nsamples = len(snr_series[0].data.data)
    if any(nsamples != len(series.data.data) for series in snr_series):
        raise ValueError('BAYESTAR does not yet support SNR time series of '
                         'mixed lengths')

    # Perform sanity checks that the middle sample of the SNR time series match
    # the sngl_inspiral records to the nearest sample (plus the smallest
    # representable LIGOTimeGPS difference of 1 nanosecond).
    for ifo, single, series in zip(ifos, singles, snr_series):
        shift = np.abs(0.5 * (nsamples - 1) * series.deltaT +
                       float(series.epoch - single.time))
        if shift >= deltaT + 1e-8:
            raise ValueError('BAYESTAR expects the SNR time series to be '
                             'centered on the single-detector trigger times, '
                             'but {} was off by {} s'.format(ifo, shift))

    # Extract the TOAs in GPS nanoseconds from the SNR time series, assuming
    # that the trigger happened in the middle.
    toas_ns = [series.epoch.ns() + 1e9 * 0.5 * (len(series.data.data) - 1) *
               series.deltaT for series in snr_series]

    # Collect all of the SNR series in one array.
    snr_series = np.vstack([series.data.data for series in snr_series])

    # Center times of arrival and compute GMST at mean arrival time.
    # Pre-center in integer nanoseconds to preserve precision of
    # initial datatype.
    epoch = sum(toas_ns) // len(toas_ns)
    toas = 1e-9 * (np.asarray(toas_ns) - epoch)
    mean_toa = np.average(toas, weights=weights)
    toas -= mean_toa
    epoch += int(np.round(1e9 * mean_toa))
    epoch = lal.LIGOTimeGPS(0, int(epoch))

    # Translate SNR time series back to time of first sample.
    toas -= 0.5 * (nsamples - 1) * deltaT

    # Convert complex SNRS to amplitude and phase
    snrs_abs = np.abs(snr_series)
    snrs_arg = filter.unwrap(np.angle(snr_series))
    snrs = np.stack((snrs_abs, snrs_arg), axis=-1)

    return epoch, sample_rate, toas, snrs, responses, locations, horizons


def condition_prior(horizons, min_distance=None, max_distance=None,
                    prior_distance_power=None, cosmology=False):
    """Fill in default values for the distance prior settings.

    Defaults: ``min_distance`` of 0 Mpc; ``max_distance`` of one quarter of
    the largest horizon distance; ``prior_distance_power`` of 2 (uniform in
    volume). Raises `ValueError` for a negative power law starting at zero
    distance, which would be undefined.
    """
    if cosmology:
        log.warning('Enabling cosmological prior. This feature is UNREVIEWED.')

    # Apply defaults for any unspecified settings.
    min_distance = 0 if min_distance is None else min_distance
    max_distance = (max(horizons) / 4 if max_distance is None
                    else max_distance)
    prior_distance_power = (2 if prior_distance_power is None
                            else prior_distance_power)

    # A power law r**k with k < 0 diverges at r = 0, so reject that
    # combination.
    if min_distance == 0 and prior_distance_power < 0:
        raise ValueError(('Prior is a power law r^k with k={}, '
                          'undefined at min_distance=0').format(
                              prior_distance_power))

    return min_distance, max_distance, prior_distance_power, cosmology


def localize(
        event, waveform='o2-uberbank', f_low=30.0,
        min_distance=None, max_distance=None, prior_distance_power=None,
        cosmology=False, mcmc=False, chain_dump=None,
        enable_snr_series=True, f_high_truncate=0.95,
        rescale_loglikelihood=_RESCALE_LOGLIKELIHOOD):
    """Localize a compact binary signal using the BAYESTAR algorithm.

    Parameters
    ----------
    event : `ligo.skymap.io.events.Event`
        The event candidate.
    waveform : str, optional
        The name of the waveform approximant.
    f_low : float, optional
        The low frequency cutoff.
    min_distance, max_distance : float, optional
        The limits of integration over luminosity distance, in Mpc
        (default: determine automatically from detector sensitivity).
    prior_distance_power : int, optional
        The power of distance that appears in the prior
        (default: 2, uniform in volume).
    cosmology: bool, optional
        Set to enable a uniform in comoving volume prior (default: false).
    mcmc : bool, optional
        Set to use MCMC sampling rather than more accurate Gaussian quadrature.
    chain_dump : str, optional
        Save posterior samples to this filename if `mcmc` is set.
    enable_snr_series : bool, optional
        Set to False to disable SNR time series.
    f_high_truncate : float, optional
        Truncate the noise power spectral densities at this factor times the
        highest sampled frequency to suppress artifacts caused by incorrect
        PSD conditioning by some matched filter pipelines.
    rescale_loglikelihood : float, optional
        Scale factor applied to the log likelihood
        (NOTE(review): exact semantics not visible in this function; it is
        passed straight through to the sampler/quadrature backends —
        confirm against ``core.toa_phoa_snr``).

    Returns
    -------
    skymap : `astropy.table.Table`
        A 3D sky map in multi-order HEALPix format.

    """
    # Hide event parameters, but show all other arguments
    def formatvalue(value):
        if isinstance(value, Event):
            return '=...'
        else:
            return '=' + repr(value)

    # Capture the call arguments as a string so the invocation can be
    # recorded in the output metadata's HISTORY entries below.
    frame = inspect.currentframe()
    argstr = inspect.formatargvalues(*inspect.getargvalues(frame),
                                     formatvalue=formatvalue)

    stopwatch = Stopwatch()
    stopwatch.start()

    # Condition the event data (SNR series, responses, horizons, etc.).
    epoch, sample_rate, toas, snrs, responses, locations, horizons = \
        condition(event, waveform=waveform, f_low=f_low,
                  enable_snr_series=enable_snr_series,
                  f_high_truncate=f_high_truncate)

    # Fill in defaults for the distance prior and validate it.
    min_distance, max_distance, prior_distance_power, cosmology = \
        condition_prior(horizons, min_distance, max_distance,
                        prior_distance_power, cosmology)

    # Greenwich Mean Sidereal Time at the (pre-centered) epoch.
    gmst = lal.GreenwichMeanSiderealTime(epoch)

    # Time and run sky localization.
    log.debug('starting computationally-intensive section')
    args = (min_distance, max_distance, prior_distance_power, cosmology, gmst,
            sample_rate, toas, snrs, responses, locations, horizons,
            rescale_loglikelihood)
    if mcmc:
        # Twice the duration of the SNR time series; used below to bound
        # the arrival-time coordinate of the sampler.
        max_abs_t = 2 * snrs.data.shape[1] / sample_rate
        skymap = localize_emcee(
            args=args,
            xmin=[0, -1, min_distance, -1, 0, 0],
            xmax=[2 * np.pi, 1, max_distance, 1, 2 * np.pi, 2 * max_abs_t],
            chain_dump=chain_dump)
    else:
        skymap, log_bci, log_bsn = core.toa_phoa_snr(*args)
        skymap = Table(skymap, copy=False)
        skymap.meta['log_bci'] = log_bci
        skymap.meta['log_bsn'] = log_bsn

    # Convert distance moments to parameters.
    # Depending on the code path, the table carries either the moments
    # (DISTMEAN/DISTSTD) or the ansatz parameters (DISTMU/DISTSIGMA);
    # produce both representations.
    try:
        distmean = skymap.columns.pop('DISTMEAN')
        diststd = skymap.columns.pop('DISTSTD')
    except KeyError:
        distmean, diststd, _ = distance.parameters_to_moments(
            skymap['DISTMU'], skymap['DISTSIGMA'])
    else:
        skymap['DISTMU'], skymap['DISTSIGMA'], skymap['DISTNORM'] = \
            distance.moments_to_parameters(distmean, diststd)

    # Add marginal distance moments, weighting each pixel's moments by its
    # probability (pixel area times probability density).
    good = np.isfinite(distmean) & np.isfinite(diststd)
    prob = (moc.uniq2pixarea(skymap['UNIQ']) * skymap['PROBDENSITY'])[good]
    distmean = distmean[good]
    diststd = diststd[good]
    rbar = (prob * distmean).sum()
    r2bar = (prob * (np.square(diststd) + np.square(distmean))).sum()
    skymap.meta['distmean'] = rbar
    skymap.meta['diststd'] = np.sqrt(r2bar - np.square(rbar))

    stopwatch.stop()
    end_time = lal.GPSTimeNow()
    log.info('finished computationally-intensive section in %s', stopwatch)

    # Fill in metadata and return.
    program, _ = os.path.splitext(os.path.basename(sys.argv[0]))
    skymap.meta.update(metadata_for_version_module(version))
    skymap.meta['creator'] = 'BAYESTAR'
    skymap.meta['origin'] = 'LIGO/Virgo/KAGRA'
    skymap.meta['gps_time'] = float(epoch)
    skymap.meta['runtime'] = stopwatch.real
    skymap.meta['instruments'] = {single.detector for single in event.singles}
    skymap.meta['gps_creation_time'] = end_time
    skymap.meta['history'] = [
        '',
        'Generated by calling the following Python function:',
        *wrap('{}.{}{}'.format(__name__, frame.f_code.co_name, argstr), 72),
        '',
        'This was the command line that started the program:',
        *wrap(' '.join([program] + sys.argv[1:]), 72)]

    return skymap


def rasterize(skymap, order=None):
    """Flatten a multi-order (UNIQ-indexed) sky map to a single HEALPix order.

    When downsampling a map that has distance layers, the distance ansatz
    parameters are first converted to probability-weighted moments so that
    the averaging performed by :func:`moc.rasterize` marginalizes them
    correctly, then converted back afterwards.
    """
    max_order, _ = moc.uniq2nest(skymap['UNIQ'].max())

    # Nontrivial downsampling: a coarser target order was requested and the
    # map carries distance layers that must be marginalized.
    must_marginalize = (order is not None and 0 <= order < max_order
                        and 'DISTMU' in skymap.dtype.fields.keys())

    if must_marginalize:
        # Work on a copy so the caller's table is left untouched.
        skymap = Table(skymap, copy=True, meta=skymap.meta)

        probdensity = skymap['PROBDENSITY']
        distmu = skymap.columns.pop('DISTMU')
        distsigma = skymap.columns.pop('DISTSIGMA')

        # Convert ansatz parameters to moments, blanking invalid pixels.
        invalid = ~(np.isfinite(distmu) & np.isfinite(distsigma))
        distmean, diststd, _ = distance.parameters_to_moments(
            distmu, distsigma)
        distmean[invalid] = np.nan
        diststd[invalid] = np.nan
        skymap['DISTMEAN'] = probdensity * distmean
        skymap['DISTVAR'] = probdensity * (
            np.square(diststd) + np.square(distmean))

    skymap = Table(moc.rasterize(skymap, order=order),
                   meta=skymap.meta, copy=False)

    # Undo the moment encoding: recover the per-pixel moments and convert
    # back to the ansatz parameters.
    if must_marginalize:
        distmean = skymap.columns.pop('DISTMEAN') / skymap['PROBDENSITY']
        diststd = np.sqrt(
            skymap.columns.pop('DISTVAR') / skymap['PROBDENSITY']
            - np.square(distmean))
        skymap['DISTMU'], skymap['DISTSIGMA'], skymap['DISTNORM'] = \
            distance.moments_to_parameters(
                distmean, diststd)

    # Convert probability density per steradian to probability per pixel.
    skymap.rename_column('PROBDENSITY', 'PROB')
    skymap['PROB'] *= 4 * np.pi / len(skymap)
    skymap['PROB'].unit = u.pixel ** -1
    return skymap


def derasterize(skymap):
    """Convert a flat HEALPix sky map back to multi-order (UNIQ) format.

    Inverse of :func:`rasterize` for the probability layer: probability per
    pixel becomes probability density per steradian, and the flat pixels are
    regrouped into a nested tree with UNIQ indices.
    """
    # Probability per pixel -> probability density per steradian.
    skymap.rename_column('PROB', 'PROBDENSITY')
    skymap['PROBDENSITY'] *= len(skymap) / (4 * np.pi)
    skymap['PROBDENSITY'].unit = u.steradian ** -1

    # Rebuild the nested tree and derive the UNIQ index for each node.
    tree = healpix_tree.reconstruct_nested(skymap)
    nside, _, ipix, _, _, value = zip(*tree)
    uniq = 4 * np.square(np.asarray(nside)) + np.asarray(ipix)

    # Rebuild the table, carrying over the per-column units.
    units = [column.unit for column in skymap.columns.values()]
    skymap = Table(np.stack(value), meta=skymap.meta, copy=False)
    for unit, column in zip(units, skymap.columns.values()):
        column.unit = unit

    skymap.add_column(Column(uniq, name='UNIQ'), 0)
    skymap.sort('UNIQ')
    return skymap


def test():
    """Run BAYESTAR C unit tests.

    Examples
    --------
    >>> test()
    0

    """
    # Delegate to the C extension; coerce its status code to a plain int.
    status = core.test()
    return int(status)


# Remove the decorator helper from the module namespace: it was presumably
# only needed while defining the functions above, and deleting it keeps it
# out of the module's public API (NOTE(review): confirm it is not referenced
# elsewhere at runtime).
del require_contiguous_aligned

ligo/skymap/bayestar/ez_emcee.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import numpy as np
from tqdm import tqdm

from .ptemcee import Sampler

__all__ = ('ez_emcee',)


def logp(x, lo, hi):
    """Flat log-prior over the box [lo, hi]: 0 inside, -inf outside."""
    inside = ((x >= lo) & (x <= hi)).all(-1)
    return np.where(inside, 0.0, -np.inf)


def ez_emcee(log_prob_fn, lo, hi, nindep=200,
             ntemps=10, nwalkers=None, nburnin=500,
             args=(), kwargs=None, **options):
    r'''Fire-and-forget MCMC sampling using `ptemcee.Sampler`, featuring
    automated convergence monitoring, progress tracking, and thinning.

    The parameters are bounded in the finite interval described by ``lo`` and
    ``hi`` (including ``-np.inf`` and ``np.inf`` for half-infinite or infinite
    domains).

    If run in an interactive terminal, live progress is shown including the
    current sample number, the total required number of samples, time elapsed
    and estimated time remaining, acceptance fraction, and autocorrelation
    length.

    Sampling terminates when all chains have accumulated the requested number
    of independent samples.

    Parameters
    ----------
    log_prob_fn : callable
        The log probability function. It should take as its argument the
        parameter vector as an of length ``ndim``, or if it is vectorized, a 2D
        array with ``ndim`` columns.
    lo : list, `numpy.ndarray`
        List of lower limits of parameters, of length ``ndim``.
    hi : list, `numpy.ndarray`
        List of upper limits of parameters, of length ``ndim``.
    nindep : int, optional
        Minimum number of independent samples.
    ntemps : int, optional
        Number of temperatures.
    nwalkers : int, optional
        Number of walkers. The default is 4 times the number of dimensions.
    nburnin : int, optional
        Number of samples to discard during burn-in phase.

    Returns
    -------
    chain : `numpy.ndarray`
        The thinned and flattened posterior sample chain,
        with at least ``nindep`` * ``nwalkers`` rows
        and exactly ``ndim`` columns.

    Other parameters
    ----------------
    kwargs :
        Extra keyword arguments for `ptemcee.Sampler`.
        *Tip:* Consider setting the `pool` or `vectorized` keyword arguments in
        order to speed up likelihood evaluations.

    Notes
    -----
    The autocorrelation length, which has a complexity of :math:`O(N \log N)`
    in the number of samples, is recalculated at geometrically progressing
    intervals so that its amortized complexity per sample is constant. (In
    simpler terms, as the chains grow longer and the autocorrelation length
    takes longer to compute, we update it less frequently so that it is never
    more expensive than sampling the chain in the first place.)

    Examples
    --------
    >>> from ligo.skymap.bayestar.ez_emcee import ez_emcee
    >>> from matplotlib import pyplot as plt
    >>> import numpy as np
    >>>
    >>> def log_prob(params):
    ...     """Eggbox function"""
    ...     return 5 * np.log((2 + np.cos(0.5 * params).prod(-1)))
    ...
    >>> lo = [-3*np.pi, -3*np.pi]
    >>> hi = [+3*np.pi, +3*np.pi]
    >>> chain = ez_emcee(log_prob, lo, hi, vectorize=True)   # doctest: +SKIP
    Sampling:  51%|██  | 8628/16820 [00:04<00:04, 1966.74it/s, accept=0.535, acl=62]
    >>> plt.plot(chain[:, 0], chain[:, 1], '.')   # doctest: +SKIP

    .. image:: eggbox.png

    '''  # noqa: E501
    # Fix for the mutable-default-argument pitfall: the old default
    # ``kwargs={}`` was shared across calls. Bind a fresh dict per call;
    # passing an explicit dict behaves exactly as before.
    if kwargs is None:
        kwargs = {}

    lo = np.asarray(lo)
    hi = np.asarray(hi)
    ndim = len(lo)

    if nwalkers is None:
        nwalkers = 4 * ndim

    nsteps = 64

    with tqdm(total=nburnin + nindep * nsteps) as progress:

        sampler = Sampler(nwalkers, ndim, log_prob_fn, logp,
                          ntemps=ntemps, loglargs=args, loglkwargs=kwargs,
                          logpargs=[lo, hi], random=np.random, **options)
        # Initial walker positions: uniform draws inside the box.
        pos = np.random.uniform(lo, hi, (ntemps, nwalkers, ndim))

        # Burn in
        progress.set_description('Burning in')
        for pos, _, _ in sampler.sample(
                pos, iterations=nburnin, storechain=False):
            progress.update()

        sampler.reset()
        acl = np.nan
        # Keep sampling until the autocorrelation length (acl) is finite and
        # the chain holds at least nindep independent samples.
        while not np.isfinite(acl) or sampler.time < nindep * acl:

            # Advance the chain
            progress.total = nburnin + max(sampler.time + nsteps,
                                           nindep * acl)
            progress.set_description('Sampling')
            for pos, _, _ in sampler.sample(pos, iterations=nsteps):
                progress.update()

            # Refresh convergence statistics
            progress.set_description('Checking')
            acl = sampler.get_autocorr_time()[0].max()
            if np.isfinite(acl):
                acl = max(1, int(np.ceil(acl)))
            accept = np.mean(sampler.acceptance_fraction[0])
            progress.set_postfix(acl=acl, accept=accept)

            # The autocorrelation time calculation has complexity N log N in
            # the number of posterior samples. Only refresh the autocorrelation
            # length estimate on logarithmically spaced samples so that the
            # amortized complexity per sample is constant.
            nsteps *= 2

    # Thin by the autocorrelation length and flatten walkers into rows.
    chain = sampler.chain[0, :, ::acl, :]
    s = chain.shape
    return chain.reshape((-1, s[-1]))

ligo/skymap/bayestar/filter.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
548  
549  
550  
551  
552  
553  
554  
555  
556  
557  
558  
559  
560  
561  
562  
563  
564  
565  
566  
567  
568  
569  
570  
571  
572  
573  
574  
575  
576  
577  
578  
579  
580  
581  
582  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Utility functions for BAYESTAR that are related to matched filtering."""
from contextlib import contextmanager
import logging
import math

import lal
import lalsimulation
import numpy as np
from scipy import interpolate
from scipy import fftpack as fft
from scipy import linalg

log = logging.getLogger('BAYESTAR')


@contextmanager
def lal_ndebug():
    """Temporarily disable lal error messages, except for memory errors."""
    # Bits to suppress while inside the context.
    suppress = (lal.LALERRORBIT |
                lal.LALWARNINGBIT |
                lal.LALINFOBIT |
                lal.LALTRACEBIT)
    saved = lal.GetDebugLevel()
    lal.ClobberDebugLevel(saved & ~suppress)
    try:
        yield
    finally:
        # Always restore the caller's debug level.
        lal.ClobberDebugLevel(saved)


def unwrap(y, *args, **kwargs):
    """Unwrap phases, passing NaN or infinite values through untouched.

    A thin wrapper around :meth:`numpy.unwrap` that tolerates invalid
    entries: finite samples are unwrapped as usual, while non-finite
    samples are copied to the output unchanged.

    Examples
    --------
    >>> t = np.arange(0, 2 * np.pi, 0.5)
    >>> y = np.exp(1j * t)
    >>> unwrap(np.angle(y))
    array([0. , 0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. ])
    >>> y[3] = y[4] = y[7] = np.nan
    >>> unwrap(np.angle(y))
    array([0. , 0.5, 1. , nan, nan, 2.5, 3. , nan, 4. , 4.5, 5. , 5.5, 6. ])
    """
    finite = np.isfinite(y)
    out = np.empty_like(y)
    out[finite] = np.unwrap(y[finite], *args, **kwargs)
    out[~finite] = y[~finite]
    return out


def ceil_pow_2(n):
    """Return the least integer power of 2 that is greater than or equal to n.

    Examples
    --------
    >>> ceil_pow_2(128.0)
    128.0
    >>> ceil_pow_2(0.125)
    0.125
    >>> ceil_pow_2(129.0)
    256.0
    >>> ceil_pow_2(0.126)
    0.25
    >>> ceil_pow_2(1.0)
    1.0

    """
    # frexp decomposes n as mantissa * 2**exponent with mantissa in
    # [0.5, 1) for positive n; ldexp reassembles a float from those parts.
    mantissa, exponent = math.frexp(n)
    # An exact power of two has mantissa 0.5, in which case n itself is
    # already the answer (so the exponent is reduced by one).
    if mantissa == 0.5:
        exponent -= 1
    return math.ldexp(1 if mantissa >= 0 else float('nan'), exponent)


def abscissa(series):
    """Produce the independent variable for a lal TimeSeries or
    FrequencySeries.
    """
    # EAFP: a TimeSeries has deltaT/epoch; a FrequencySeries has deltaF/f0.
    try:
        step, start = series.deltaT, float(series.epoch)
    except AttributeError:
        step, start = series.deltaF, series.f0
    return start + step * np.arange(len(series.data.data))


def exp_i(phi):
    """Return exp(i * phi) computed from its cosine and sine parts."""
    return np.cos(phi) + 1j * np.sin(phi)


def truncated_ifft(y, nsamples_out=None):
    r"""Truncated inverse FFT.

    See http://www.fftw.org/pruned.html for a discussion of related algorithms.

    Perform inverse FFT to obtain truncated autocorrelation time series.
    This makes use of a folded DFT for a speedup of::

        log(nsamples)/log(nsamples_out)

    over directly computing the inverse FFT and truncating. Here is how it
    works. Say we have a frequency-domain signal X[k], for 0 ≤ k ≤ N - 1. We
    want to compute its DFT x[n], for 0 ≤ n ≤ M, where N is divisible by M:
    N = cM, for some integer c. The DFT is::

               N - 1
               ______
               \           2 π i k n
        x[n] =  \     exp[-----------] Y[k]
               /               N
              /------
               k = 0

               c - 1   M - 1
               ______  ______
               \       \           2 π i n (m c + j)
             =  \       \     exp[------------------] Y[m c + j]
               /       /                 c M
              /------ /------
               j = 0   m = 0

               c - 1                     M - 1
               ______                    ______
               \           2 π i n j     \           2 π i n m
             =  \     exp[-----------]    \     exp[-----------] Y[m c + j]
               /               N         /               M
              /------                   /------
               j = 0                     m = 0

    So: we split the frequency series into c deinterlaced sub-signals, each of
    length M, compute the DFT of each sub-signal, and add them back together
    with complex weights.

    Parameters
    ----------
    y : `numpy.ndarray`
        Complex input vector.
    nsamples_out : int, optional
        Length of output vector. By default, same as length of input vector.

    Returns
    -------
    x : `numpy.ndarray`
        The first nsamples_out samples of the IFFT of x, zero-padded if
        necessary.

    Examples
    --------
    First generate the IFFT of a random signal:

    >>> nsamples_out = 1024
    >>> y = np.random.randn(nsamples_out) + np.random.randn(nsamples_out) * 1j
    >>> x = fft.ifft(y)

    Now check that the truncated IFFT agrees:

    >>> np.allclose(x, truncated_ifft(y), rtol=1e-15)
    True
    >>> np.allclose(x, truncated_ifft(y, 1024), rtol=1e-15)
    True
    >>> np.allclose(x[:128], truncated_ifft(y, 128), rtol=1e-15)
    True
    >>> np.allclose(x[:1], truncated_ifft(y, 1), rtol=1e-15)
    True
    >>> np.allclose(x[:32], truncated_ifft(y, 32), rtol=1e-15)
    True
    >>> np.allclose(x[:63], truncated_ifft(y, 63), rtol=1e-15)
    True
    >>> np.allclose(x[:25], truncated_ifft(y, 25), rtol=1e-15)
    True
    >>> truncated_ifft(y, 1025)
    Traceback (most recent call last):
      ...
    ValueError: Input is too short: you gave me an input of length 1024, but you asked for an IFFT of length 1025.

    """  # noqa: E501
    # Validate arguments: the output cannot be longer than the input, and
    # the folding trick below assumes a power-of-2 input length.
    nsamples = len(y)
    if nsamples_out is None:
        nsamples_out = nsamples
    elif nsamples_out > nsamples:
        raise ValueError(
            'Input is too short: you gave me an input of length {0}, '
            'but you asked for an IFFT of length {1}.'.format(
                nsamples, nsamples_out))
    elif nsamples & (nsamples - 1):
        raise NotImplementedError(
            'I am too lazy to implement for nsamples that is '
            'not a power of 2.')

    # Find number of FFTs.
    # FIXME: only works if nsamples is a power of 2.
    # Would be better to find the smallest divisor of nsamples that is
    # greater than or equal to nsamples_out.
    nsamples_batch = int(ceil_pow_2(nsamples_out))
    c = nsamples // nsamples_batch

    # FIXME: Implement for real-to-complex FFTs as well.
    # Complex phase ramp (the "twiddle factors") that aligns each
    # deinterlaced sub-IFFT before accumulation.
    twiddle = exp_i(2 * np.pi * np.arange(nsamples_batch) / nsamples)

    # IFFT each of the c deinterlaced sub-signals of length nsamples_batch.
    x = fft.ifft(y.reshape(nsamples_batch, c).T)

    # Horner-style accumulation of the weighted sub-IFFTs; note that
    # `result` is a view into `x` that is updated in place.
    result = x[-1]
    for row in x[-2::-1]:
        result *= twiddle  # FIXME: check stability of this recurrence relation
        result += row

    # Now need to truncate remaining samples.
    if nsamples_out < nsamples_batch:
        result = result[:nsamples_out]

    # fft.ifft normalized each sub-IFFT by 1/nsamples_batch; dividing by c
    # completes the overall 1/nsamples normalization.
    return result / c


@lal_ndebug()
def get_approximant_and_orders_from_string(s):
    """Determine the approximant, amplitude order, and phase order for a string
    of the form "TaylorT4threePointFivePN". In this example, the waveform is
    "TaylorT4" and the phase order is 7 (twice 3.5). If the input contains the
    substring "restricted" or "Restricted", then the amplitude order is taken
    to be 0. Otherwise, the amplitude order is the same as the phase order.
    """
    # SWIG-wrapped functions apparently do not understand Unicode, but
    # often the input argument will come from a Unicode XML file.
    s = str(s)
    approximant = lalsimulation.GetApproximantFromString(s)
    # A string with no parseable PN order yields -1.
    try:
        phase_order = lalsimulation.GetOrderFromString(s)
    except RuntimeError:
        phase_order = -1
    restricted = 'restricted' in s or 'Restricted' in s
    amplitude_order = 0 if restricted else phase_order
    return approximant, amplitude_order, phase_order


def get_f_lso(mass1, mass2):
    """Calculate the GW frequency during the last stable orbit of a compact
    binary.
    """
    # Same arithmetic as 1 / (6**1.5 * pi * M * T_sun), factored for clarity.
    denominator = 6 ** 1.5 * np.pi * (mass1 + mass2) * lal.MTSUN_SI
    return 1 / denominator


def sngl_inspiral_psd(waveform, mass1, mass2,
                      f_min=10, f_final=None, f_ref=None, **kwargs):
    """Compute the signal power spectral density for a single inspiral.

    Generates a frequency-domain template with lalsimulation, combines the
    plus and cross polarizations in quadrature, and returns |h(f)|^2 as a
    ``lal.REAL8FrequencySeries``.

    Parameters
    ----------
    waveform : str
        Waveform approximant string, or one of the special values
        'o1-uberbank' / 'o2-uberbank', which select an approximant based on
        the total mass.
    mass1, mass2 : float
        Component masses in solar masses
        (NOTE(review): units inferred from the ``lal.MSUN_SI`` scaling below
        — confirm).
    f_min : float, optional
        Low-frequency cutoff passed to ``SimInspiralFD``.
    f_final : float, optional
        High-frequency cutoff; 2048 Hz is used when not given.
    f_ref : float, optional
        Reference frequency passed to ``SimInspiralFD`` (0 when not given).
    **kwargs
        Optional spin components spin1x..spin2z; missing or falsy values are
        treated as 0.

    Returns
    -------
    lal.REAL8FrequencySeries
        The signal PSD series.
    """
    # FIXME: uberbank mass criterion. Should find a way to get this from
    # pipeline output metadata.
    if waveform == 'o1-uberbank':
        log.warning('Template is unspecified; '
                    'using ER8/O1 uberbank criterion')
        if mass1 + mass2 < 4:
            waveform = 'TaylorF2threePointFivePN'
        else:
            waveform = 'SEOBNRv2_ROM_DoubleSpin'
    elif waveform == 'o2-uberbank':
        log.warning('Template is unspecified; '
                    'using ER10/O2 uberbank criterion')
        if mass1 + mass2 < 4:
            waveform = 'TaylorF2threePointFivePN'
        else:
            waveform = 'SEOBNRv4_ROM'
    approx, ampo, phaseo = get_approximant_and_orders_from_string(waveform)
    log.info('Selected template: %s', waveform)

    # Generate conditioned template.
    params = lal.CreateDict()
    lalsimulation.SimInspiralWaveformParamsInsertPNPhaseOrder(params, phaseo)
    lalsimulation.SimInspiralWaveformParamsInsertPNAmplitudeOrder(params, ampo)
    hplus, hcross = lalsimulation.SimInspiralFD(
        m1=float(mass1) * lal.MSUN_SI, m2=float(mass2) * lal.MSUN_SI,
        S1x=float(kwargs.get('spin1x') or 0),
        S1y=float(kwargs.get('spin1y') or 0),
        S1z=float(kwargs.get('spin1z') or 0),
        S2x=float(kwargs.get('spin2x') or 0),
        S2y=float(kwargs.get('spin2y') or 0),
        S2z=float(kwargs.get('spin2z') or 0),
        distance=1e6 * lal.PC_SI, inclination=0, phiRef=0,
        longAscNodes=0, eccentricity=0, meanPerAno=0,
        deltaF=0, f_min=f_min,
        # Note: code elsewhere ensures that the sample rate is at least two
        # times f_final; the factor of 2 below is just a safety factor to make
        # sure that the sample rate is 2-4 times f_final.
        f_max=ceil_pow_2(2 * (f_final or 2048)),
        f_ref=float(f_ref or 0),
        LALparams=params, approximant=approx)

    # Force `plus' and `cross' waveform to be in quadrature.
    h = 0.5 * (hplus.data.data + 1j * hcross.data.data)

    # For inspiral-only waveforms, nullify frequencies beyond ISCO.
    # FIXME: the waveform generation functions pick the end frequency
    # automatically. Shouldn't SimInspiralFD?
    inspiral_only_waveforms = (
        lalsimulation.TaylorF2,
        lalsimulation.SpinTaylorF2,
        lalsimulation.TaylorF2RedSpin,
        lalsimulation.TaylorF2RedSpinTidal,
        lalsimulation.SpinTaylorT4Fourier)
    if approx in inspiral_only_waveforms:
        h[abscissa(hplus) >= get_f_lso(mass1, mass2)] = 0

    # Throw away any frequencies above high frequency cutoff
    h[abscissa(hplus) >= (f_final or 2048)] = 0

    # Drop Nyquist frequency.
    if len(h) % 2:
        h = h[:-1]

    # Create output frequency series.
    psd = lal.CreateREAL8FrequencySeries(
        'signal PSD', 0, hplus.f0, hcross.deltaF, hplus.sampleUnits**2, len(h))
    psd.data.data = abs2(h)

    # Done!
    return psd


def signal_psd_series(H, S):
    """Divide a signal PSD series by a noise PSD function, bin by bin.

    The DC bin is set to 0 rather than evaluating ``S`` at zero frequency.
    """
    n = H.data.data.size
    out = lal.CreateREAL8FrequencySeries(
        'signal PSD / noise PSD', 0, H.f0, H.deltaF, lal.DimensionlessUnit, n)
    # Frequencies of every bin except DC.
    freqs = H.f0 + H.deltaF * np.arange(1, n)
    out.data.data[0] = 0
    out.data.data[1:] = H.data.data[1:] / S(freqs)
    return out


def autocorrelation(H, out_duration, normalize=True):
    """Calculate the complex autocorrelation sequence a(t), for t >= 0, of an
    inspiral signal.

    Parameters
    ----------
    H : lal.REAL8FrequencySeries
        Signal PSD series.
    out_duration : float
        Duration, in seconds, of the output autocorrelation sequence.
    normalize : bool, optional
        If True (default), scale the sequence so that ``|acor[0]| == 1``.

    Returns
    -------
    acor : `numpy.ndarray`
        The complex-valued autocorrelation sequence.
    sample_rate : float
        The sample rate.

    """
    # Compute duration of template, rounded up to a power of 2.
    H_len = H.data.data.size
    nsamples = 2 * H_len
    sample_rate = nsamples * H.deltaF

    # Compute autopower spectral density.
    # Zero-pad the upper half so the IFFT yields the one-sided sequence.
    power = np.empty(nsamples, H.data.data.dtype)
    power[:H_len] = H.data.data
    power[H_len:] = 0

    # Determine length of output FFT.
    nsamples_out = int(np.ceil(out_duration * sample_rate))

    acor = truncated_ifft(power, nsamples_out)
    if normalize:
        acor /= np.abs(acor[0])

    # If we have done this right, then the zeroth sample represents lag 0
    if np.all(np.isreal(H.data.data)):
        assert np.argmax(np.abs(acor)) == 0
        assert np.isreal(acor[0])

    # Done!
    return acor, float(sample_rate)


def abs2(y):
    """Return the absolute value squared, :math:`|z|^2`, for a complex number
    :math:`z`, without performing a square root.
    """
    re, im = y.real, y.imag
    return np.square(re) + np.square(im)


class vectorize_swig_psd_func:  # noqa: N801
    """Create a vectorized Numpy function from a SWIG-wrapped PSD function.
    SWIG does not provide enough information for Numpy to determine the number
    of input arguments, so we can't just use np.vectorize.
    """

    def __init__(self, str):
        # NOTE(review): the parameter name shadows the builtin ``str``; it is
        # kept as-is for backward compatibility with keyword callers.
        # ``str`` names a lalsimulation PSD function; the 'Ptr' variant is
        # the C function pointer used by SimNoisePSD below.
        self.__func = getattr(lalsimulation, str + 'Ptr')
        # Element-wise Python fallback for the non-uniform-grid case.
        self.__npyfunc = np.frompyfunc(getattr(lalsimulation, str), 1, 1)

    def __call__(self, f):
        fa = np.asarray(f)
        df = np.diff(fa)
        # Fast path: a 1D, uniformly spaced frequency grid can be evaluated
        # in a single C call via SimNoisePSD on a frequency series.
        if fa.ndim == 1 and df.size > 1 and np.all(df[0] == df[1:]):
            # Append one extra sample, which is trimmed off again below
            # (presumably to work around SimNoisePSD's handling of the last
            # bin — TODO confirm).
            fa = np.concatenate((fa, [fa[-1] + df[0]]))
            ret = lal.CreateREAL8FrequencySeries(
                None, 0, fa[0], df[0], lal.DimensionlessUnit, fa.size)
            lalsimulation.SimNoisePSD(ret, 0, self.__func)
            ret = ret.data.data[:-1]
        else:
            # Slow path: evaluate the wrapped function element-wise.
            ret = self.__npyfunc(f)
        if not np.isscalar(ret):
            # frompyfunc produces object arrays; convert to float.
            ret = ret.astype(float)
        return ret


class InterpolatedPSD(interpolate.interp1d):
    """Interpolating function for a discretely sampled power spectrum S(f).

    Interpolation is performed linearly in log-log space. Frequencies
    outside the sampled band evaluate to ``fill_value`` (infinite by
    default).
    """

    def __init__(self, f, S, f_high_truncate=1.0, fill_value=np.inf):
        assert f_high_truncate <= 1.0
        f = np.asarray(f)
        S = np.asarray(S)

        # Exclude DC if present
        if f[0] == 0:
            f, S = f[1:], S[1:]
        # FIXME: This is a hack to fix an issue with the detection pipeline's
        # PSD conditioning. Remove this when the issue is fixed upstream.
        if f_high_truncate < 1.0:
            log.warning(
                'Truncating PSD at %g of maximum frequency to suppress '
                'rolloff artifacts. This option may be removed in the future.',
                f_high_truncate)
            mask = (f <= f_high_truncate * max(f))
            f, S = f[mask], S[mask]
        super().__init__(
            np.log(f), np.log(S),
            kind='linear', bounds_error=False, fill_value=np.log(fill_value))
        self._f_min = min(f)
        self._f_max = max(f)

    @property
    def f_min(self):
        # Lowest sampled frequency.
        return self._f_min

    @property
    def f_max(self):
        # Highest sampled frequency.
        return self._f_max

    def __call__(self, f):
        lo, hi = np.min(f), np.max(f)
        if lo < self._f_min:
            log.warning('Assuming PSD is infinite at %g Hz because PSD is '
                        'only sampled down to %g Hz', lo, self._f_min)
        if hi > self._f_max:
            log.warning('Assuming PSD is infinite at %g Hz because PSD is '
                        'only sampled up to %g Hz', hi, self._f_max)
        in_band = (f >= self._f_min) & (f <= self._f_max)
        return np.where(in_band,
                        np.exp(super().__call__(np.log(f))),
                        np.exp(self.fill_value))


class SignalModel:
    """Class to speed up computation of signal/noise-weighted integrals and
    Barankin and Cramér-Rao lower bounds on time and phase estimation.

    Note that the autocorrelation series and the moments are related,
    as shown below.

    Examples
    --------
    Create signal model:

    >>> from . import filter
    >>> sngl = lambda: None
    >>> H = filter.sngl_inspiral_psd(
    ...     'TaylorF2threePointFivePN', mass1=1.4, mass2=1.4)
    >>> S = vectorize_swig_psd_func('SimNoisePSDaLIGOZeroDetHighPower')
    >>> W = filter.signal_psd_series(H, S)
    >>> sm = SignalModel(W)

    Compute one-sided autocorrelation function:

    >>> out_duration = 0.1
    >>> a, sample_rate = filter.autocorrelation(W, out_duration)

    Restore negative time lags using symmetry:

    >>> a = np.concatenate((a[:0:-1].conj(), a))

    Compute the first 2 frequency moments by taking derivatives of the
    autocorrelation sequence using centered finite differences.
    The nth frequency moment should be given by (-1j)^n a^(n)(t).

    >>> acor_moments = []
    >>> for i in range(2):
    ...     acor_moments.append(a[len(a) // 2])
    ...     a = -0.5j * sample_rate * (a[2:] - a[:-2])
    >>> assert np.all(np.isreal(acor_moments))
    >>> acor_moments = np.real(acor_moments)

    Compute the first 2 frequency moments using this class.

    >>> quad_moments = [sm.get_sn_moment(i) for i in range(2)]

    Compare them.

    >>> for i, (am, qm) in enumerate(zip(acor_moments, quad_moments)):
    ...     assert np.allclose(am, qm, rtol=0.05)

    """

    # NumPy 2.0 renamed ``np.trapz`` to ``np.trapezoid``; resolve whichever
    # name this NumPy provides so the class works on both old and new
    # versions. ``staticmethod`` prevents the plain function from being
    # bound as an instance method when accessed through ``self``.
    _trapz = staticmethod(getattr(np, 'trapezoid', None) or np.trapz)

    def __init__(self, h):
        """Create a TaylorF2 signal model with the given masses, PSD function
        S(f), PN amplitude order, and low-frequency cutoff.
        """
        # Find indices of first and last nonzero samples.
        nonzero = np.flatnonzero(h.data.data)
        first_nonzero = nonzero[0]
        last_nonzero = nonzero[-1]

        # Angular frequency spacing and sample points of the nonzero band.
        self.dw = 2 * np.pi * h.deltaF
        f = h.f0 + h.deltaF * np.arange(first_nonzero, last_nonzero + 1)
        self.w = 2 * np.pi * f

        # Throw away leading and trailing zeros.
        h = h.data.data[first_nonzero:last_nonzero + 1]

        # Signal-to-noise per unit angular frequency, and its integral.
        self.denom_integrand = 4 / (2 * np.pi) * h
        self.den = self._trapz(self.denom_integrand, dx=self.dw)

    def get_horizon_distance(self, snr_thresh=1):
        """Return the horizon distance for the given SNR threshold
        (square root of the squared-SNR integral, divided by the threshold).
        """
        return np.sqrt(self.den) / snr_thresh

    def get_sn_average(self, func):
        """Get the average of a function of angular frequency, weighted by the
        signal to noise per unit angular frequency.
        """
        num = self._trapz(func(self.w) * self.denom_integrand, dx=self.dw)
        return num / self.den

    def get_sn_moment(self, power):
        """Get the average of angular frequency to the given power, weighted by
        the signal to noise per unit frequency.
        """
        return self.get_sn_average(lambda w: w**power)

    def get_crb(self, snr):
        """Get the Cramér-Rao bound, or inverse Fisher information matrix,
        describing the phase and time estimation covariance.
        """
        w1 = self.get_sn_moment(1)
        w2 = self.get_sn_moment(2)
        # Fisher matrix for (phase, time) at unit SNR; scale by 1/SNR^2.
        fisher = np.asarray(((1, -w1), (-w1, w2)))
        return linalg.inv(fisher) / np.square(snr)

    # FIXME: np.vectorize doesn't work on unbound instance methods. The
    # excluded keyword, added in Numpy 1.7, could be used here to exclude the
    # zeroth argument, self.
    def __get_crb_toa_uncert(self, snr):
        return np.sqrt(self.get_crb(snr)[1, 1])

    def get_crb_toa_uncert(self, snr):
        """Get the Cramér-Rao lower bound on the time-of-arrival uncertainty
        for the given SNR, broadcasting over array-valued input.
        """
        return np.frompyfunc(self.__get_crb_toa_uncert, 1, 1)(snr)

ligo/skymap/bayestar/interpolation.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Sub-sample interpolation for matched filter time series.

Example
-------
.. plot::
   :context: reset
   :include-source:
   :align: center

    from ligo.skymap.bayestar.interpolation import interpolate_max
    from matplotlib import pyplot as plt
    import numpy as np

    z = np.asarray([ 9.135017 -2.8185585j,  9.995214 -1.1222992j,
                    10.682851 +0.8188147j, 10.645139 +3.0268786j,
                     9.713133 +5.5589147j,  7.9043484+7.9039335j,
                     5.511646 +9.333084j ,  2.905198 +9.715742j ,
                     0.5302934+9.544538j ])

    amp = np.abs(z)
    arg = np.rad2deg(np.unwrap(np.angle(z)))
    arg -= (np.median(arg) // 360) * 360
    imax = np.argmax(amp)
    window = 4

    fig, (ax_amp, ax_arg) = plt.subplots(2, 1, figsize=(5, 6), sharex=True)
    ax_arg.set_xlabel('Sample index')
    ax_amp.set_ylabel('Amplitude')
    ax_arg.set_ylabel('Phase')
    args, kwargs = ('.-',), dict(color='lightgray', label='data')
    ax_amp.plot(amp, *args, **kwargs)
    ax_arg.plot(arg, *args, **kwargs)
    for method in ['lanczos', 'catmull-rom',
                   'quadratic-fit', 'nearest-neighbor']:
        i, y = interpolate_max(imax, z, window, method)
        amp = np.abs(y)
        arg = np.rad2deg(np.angle(y))
        args, kwargs = ('o',), dict(mfc='none', label=method)
        ax_amp.plot(i, amp, *args, **kwargs)
        ax_arg.plot(i, arg, *args, **kwargs)
    ax_arg.legend()
    fig.tight_layout()

"""
import numpy as np
from scipy import optimize

from .filter import abs2, exp_i, unwrap

__all__ = ('interpolate_max',)


#
# Lanczos interpolation
#


def lanczos(t, a):
    """The Lanczos kernel: sinc(t)·sinc(t/a) on |t| < a, zero outside."""
    support = np.abs(t) < a
    return np.where(support, np.sinc(t) * np.sinc(t / a), 0)


def lanczos_interpolant(t, y):
    """An interpolant constructed by convolution of the Lanczos kernel with
    a set of discrete samples at unit intervals.
    """
    half = len(y) // 2
    # Each sample contributes a kernel-weighted term; sample i sits at
    # offset i - half from the window center.
    terms = (lanczos(t - i + half, half) * yi for i, yi in enumerate(y))
    return sum(terms)


def lanczos_interpolant_utility_func(t, y):
    """Objective function for the minimizer: the negated squared magnitude
    of the Lanczos interpolant at time ``t``.
    """
    value = lanczos_interpolant(t, y)
    return -abs2(value)


def interpolate_max_lanczos(imax, y, window_length):
    """Find the time and maximum absolute value of a time series by Lanczos
    interpolation.
    """
    window = y[(imax - window_length):(imax + window_length + 1)]
    # Minimize the negated squared magnitude of the interpolant over a
    # one-sample neighborhood of the peak.
    t_peak = optimize.fminbound(
        lanczos_interpolant_utility_func, -1., 1., (window,),
        xtol=1e-5).item()
    y_peak = lanczos_interpolant(t_peak, window).item()
    return imax + t_peak, y_peak


#
# Catmull-Rom spline interpolation, real and imaginary parts
#


def poly_catmull_rom(y):
    """Return the Catmull-Rom interpolating cubic (as `numpy.poly1d`) for
    the unit interval between samples ``y[1]`` and ``y[2]`` of four
    uniformly spaced samples.
    """
    coefficients = [
        -0.5 * y[0] + 1.5 * y[1] - 1.5 * y[2] + 0.5 * y[3],
        y[0] - 2.5 * y[1] + 2 * y[2] - 0.5 * y[3],
        -0.5 * y[0] + 0.5 * y[2],
        y[1],
    ]
    return np.poly1d(coefficients)


def interpolate_max_catmull_rom_even(y):
    """Interpolate the peak over the unit interval between the two interior
    samples of a 4-sample complex window, using Catmull-Rom splines fitted
    separately to the real and imaginary parts.
    """
    # Cubic models of the real and imaginary parts over the interval.
    poly_re = poly_catmull_rom(y.real)
    poly_im = poly_catmull_rom(y.imag)

    # Stationary points of |y|^2, i.e. roots of re·re' + im·im'.
    roots = (poly_re * poly_re.deriv() + poly_im * poly_im.deriv()).r

    # Candidate peaks: the two interior samples, then every real stationary
    # point strictly inside (0, 1).
    candidates = [(0., y[1]), (1., y[2])]
    for root in roots:
        if np.isreal(root) and 0 < root < 1:
            t = np.real(root)
            candidates.append((t, poly_re(t) + poly_im(t) * 1j))

    # Keep the candidate with the greatest squared magnitude; max() returns
    # the first of equals, matching a sequential strict-greater-than scan.
    t_max, y_max = max(candidates, key=lambda cand: abs2(cand[1]))
    return t_max, y_max


def interpolate_max_catmull_rom(imax, y, window_length):
    """Sub-sample peak interpolation with real/imaginary Catmull-Rom
    splines, evaluated over both 4-sample windows that bracket ``imax``.
    """
    # Left window: imax falls at its sample index 2, so its interval
    # corresponds to offsets [-1, 0] relative to imax.
    t_left, y_left = interpolate_max_catmull_rom_even(y[imax - 2:imax + 2])
    # Right window: interval corresponds to offsets [0, 1].
    t_right, y_right = interpolate_max_catmull_rom_even(y[imax - 1:imax + 3])

    if abs2(y_right) > abs2(y_left):
        t_max, y_max = t_right, y_right
    else:
        t_max, y_max = t_left - 1, y_left
    return imax + t_max, y_max


#
# Catmull-Rom spline interpolation, amplitude and phase
#


def interpolate_max_catmull_rom_amp_phase_even(y):
    """Interpolate the peak over the unit interval between the two interior
    samples of a 4-sample complex window, using Catmull-Rom splines fitted
    to the amplitude and the unwrapped phase of the samples.
    """
    # Construct Catmull-Rom interpolating polynomials for
    # amplitude and phase
    poly_abs = poly_catmull_rom(np.abs(y))
    poly_arg = poly_catmull_rom(unwrap(np.angle(y)))

    # Find the roots of d(|y|)/dt: the stationary points of the amplitude.
    # BUG FIX: previously this took ``poly_abs.r`` — the zero crossings of
    # the amplitude itself, not its extrema — which contradicts both the
    # comment above and the analogous real/imaginary-part routine.
    roots = poly_abs.deriv().r

    # Find which of the two matched interior points has a greater magnitude
    t_max = 0.
    y_max = y[1]
    y_max_abs2 = abs2(y_max)

    new_t_max = 1.
    new_y_max = y[2]
    new_y_max_abs2 = abs2(new_y_max)

    if new_y_max_abs2 > y_max_abs2:
        t_max = new_t_max
        y_max = new_y_max
        y_max_abs2 = new_y_max_abs2

    # Find any real root in (0, 1) that has a magnitude greater than the
    # greatest endpoint
    for root in roots:
        if np.isreal(root) and 0 < root < 1:
            new_t_max = np.real(root)
            new_y_max = poly_abs(new_t_max) * exp_i(poly_arg(new_t_max))
            new_y_max_abs2 = abs2(new_y_max)
            if new_y_max_abs2 > y_max_abs2:
                t_max = new_t_max
                y_max = new_y_max
                y_max_abs2 = new_y_max_abs2

    # Done
    return t_max, y_max


def interpolate_max_catmull_rom_amp_phase(imax, y, window_length):
    """Sub-sample peak interpolation with amplitude/phase Catmull-Rom
    splines, evaluated over both 4-sample windows that bracket ``imax``.
    """
    # Left window: imax falls at its sample index 2, so its interval
    # corresponds to offsets [-1, 0] relative to imax.
    t_left, y_left = interpolate_max_catmull_rom_amp_phase_even(
        y[imax - 2:imax + 2])
    # Right window: interval corresponds to offsets [0, 1].
    t_right, y_right = interpolate_max_catmull_rom_amp_phase_even(
        y[imax - 1:imax + 3])

    if abs2(y_right) > abs2(y_left):
        t_max, y_max = t_right, y_right
    else:
        t_max, y_max = t_left - 1, y_left
    return imax + t_max, y_max


#
# Quadratic fit
#


def interpolate_max_quadratic_fit(imax, y, window_length):
    """Quadratic fit to absolute value of y. Note that this one does not alter
    the value at the maximum.
    """
    t = np.arange(-window_length, window_length + 1.)
    y = y[imax - window_length:imax + window_length + 1]
    y_abs = np.abs(y)
    a, b, c = np.polyfit(t, y_abs, 2)

    # Initial candidates: the two samples adjacent to the peak sample;
    # take whichever has the greater magnitude (left wins ties).
    if y_abs[window_length + 1] > y_abs[window_length - 1]:
        t_max = 1.
        y_max = y[window_length + 1]
        y_max_abs = y_abs[window_length + 1]
    else:
        t_max = -1.
        y_max = y[window_length - 1]
        y_max_abs = y_abs[window_length - 1]

    # Vertex of the fitted parabola: accept it if it lies strictly inside
    # (-1, 1) and exceeds the better adjacent sample.
    t_vertex = -0.5 * b / a
    y_vertex_abs = c - 0.25 * np.square(b) / a
    if -1 < t_vertex < 1 and y_vertex_abs > y_max_abs:
        t_max = t_vertex
        y_max_abs = y_vertex_abs
        # Phase is interpolated linearly through the unwrapped angles.
        y_phase = np.interp(t_max, t, np.unwrap(np.angle(y)))
        y_max = y_max_abs * exp_i(y_phase)

    return imax + t_max, y_max


#
# Nearest neighbor interpolation
#


def interpolate_max_nearest_neighbor(imax, y, window_length):
    """Trivial interpolation: return the peak sample itself, unmodified."""
    # window_length is accepted for interface compatibility but unused.
    return imax, y[imax]


#
# Set default interpolation scheme
#


# Registry mapping method names (as accepted by `interpolate_max`) to their
# implementations. All share the signature (imax, y, window_length).
_interpolants = {
    'catmull-rom-amp-phase': interpolate_max_catmull_rom_amp_phase,
    'catmull-rom': interpolate_max_catmull_rom,
    'lanczos': interpolate_max_lanczos,
    'nearest-neighbor': interpolate_max_nearest_neighbor,
    'quadratic-fit': interpolate_max_quadratic_fit}


def interpolate_max(imax, y, window_length, method='catmull-rom-amp-phase'):
    """Perform sub-sample interpolation to find the phase and amplitude
    at the maximum of the absolute value of a complex series.

    Parameters
    ----------
    imax : int
        The index of the maximum sample in the series.
    y : `numpy.ndarray`
        The complex series.
    window_length : int
        The window of the interpolation function for the `lanczos` and
        `quadratic-fit` methods. The interpolation will consider a sliding
        window of `2 * window_length + 1` samples centered on `imax`.
    method : {'catmull-rom-amp-phase', 'catmull-rom', 'lanczos', 'nearest-neighbor', 'quadratic-fit'}
        The interpolation method. Can be any of the following:

        * `catmull-rom-amp-phase`: Catmull-Rom cubic splines on amplitude and phase
          The `window_length` parameter is ignored (understood to be 2).
        * `catmull-rom`: Catmull-Rom cubic splines on real and imaginary parts
          The `window_length` parameter is ignored (understood to be 2).
        * `lanczos`: Lanczos filter interpolation
        * `nearest-neighbor`: Nearest neighbor (e.g., no interpolation).
          The `window_length` parameter is ignored (understood to be 0).
        * `quadratic-fit`: Fit the absolute value of the SNR to a quadratic
          function and the phase to a linear function.

    Returns
    -------
    imax_interp : float
        The interpolated index of the maximum sample, which should be between
        `imax - 0.5` and `imax + 0.5`.
    ymax_interp : complex
        The interpolated value at the maximum.

    """  # noqa: E501
    # Look up the requested interpolant and delegate to it.
    interpolant = _interpolants[method]
    return interpolant(imax, np.asarray(y), window_length)

ligo/skymap/bayestar/ptemcee.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
# FIXME: Remove this file if https://github.com/willvousden/ptemcee/pull/6
# is merged
import numpy as np
import ptemcee.sampler

__all__ = ('Sampler',)


class VectorLikePriorEvaluator(ptemcee.sampler.LikePriorEvaluator):
    """Like/prior evaluator that calls the log likelihood and log prior on
    whole batches of points at once instead of one point at a time.
    """

    def __call__(self, x):
        shape = x.shape
        # Collapse all leading axes into a single batch axis.
        flat = x.reshape((-1, shape[-1]))

        lp = self.logp(flat, *self.logpargs, **self.logpkwargs)
        if np.any(np.isnan(lp)):
            raise ValueError('Prior function returned NaN.')

        # Skip the likelihood wherever the prior is -inf: those points are
        # rejected regardless of their likelihood.
        ll = np.empty_like(lp)
        finite = (lp != -np.inf)
        ll[~finite] = 0
        ll[finite] = self.logl(flat[finite], *self.loglargs, **self.loglkwargs)
        if np.any(np.isnan(ll)):
            raise ValueError('Log likelihood function returned NaN.')

        return ll.reshape(shape[:-1]), lp.reshape(shape[:-1])


class Sampler(ptemcee.sampler.Sampler):
    """Patched version of :class:`ptemcee.Sampler` that supports the
    `vectorize` option of :class:`emcee.EnsembleSampler`.
    """

    def __init__(self, nwalkers, dim, logl, logp,  # noqa: N803
                 ntemps=None, Tmax=None, betas=None,  # noqa: N803
                 threads=1, pool=None, a=2.0,
                 loglargs=[], logpargs=[],
                 loglkwargs={}, logpkwargs={},
                 adaptation_lag=10000, adaptation_time=100,
                 random=None, vectorize=False):
        # NOTE: the mutable default arguments mirror the upstream ptemcee
        # signature; they are passed through without being mutated here.
        super().__init__(nwalkers, dim, logl, logp,
                         ntemps=ntemps, Tmax=Tmax, betas=betas,
                         threads=threads, pool=pool, a=a, loglargs=loglargs,
                         logpargs=logpargs, loglkwargs=loglkwargs,
                         logpkwargs=logpkwargs, adaptation_lag=adaptation_lag,
                         adaptation_time=adaptation_time, random=random)
        self._vectorize = vectorize
        if vectorize:
            # Replace the scalar evaluator with one that calls the log
            # likelihood and log prior on whole batches of points at once.
            self._likeprior = VectorLikePriorEvaluator(logl, logp,
                                                       loglargs, logpargs,
                                                       loglkwargs, logpkwargs)

    def _evaluate(self, ps):
        """Evaluate (log likelihood, log prior), vectorized if requested."""
        if self._vectorize:
            return self._likeprior(ps)
        else:
            # BUG FIX: delegate to the base class's ``_evaluate`` — the
            # method being overridden here is ``_evaluate``, so calling
            # ``super().evaluate`` would fail with an AttributeError.
            return super()._evaluate(ps)

ligo/skymap/coordinates/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

__all__ = ()

# Import all symbols from all submodules of this module.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        # Re-export the submodule and everything in its __all__ at package
        # level. exec is used because the module name is only known at
        # runtime.
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    # Avoid leaking the loop variable into the package namespace.
    del module

# Clean up
del os, pkgutil

ligo/skymap/coordinates/detector.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Astropy coordinate frames to visualize triangulation rings from pairs of
gravitational-wave detectors. These are useful for generating plots similar to
Fig. 2 of the GW150914 localization and follow-up paper [1]_.

Example
-------
.. plot::
   :context: reset
   :include-source:
   :align: center

    from astropy.coordinates import EarthLocation
    from astropy.time import Time
    from ligo.skymap.coordinates import DetectorFrame
    from ligo.skymap.io import read_sky_map
    import ligo.skymap.plot
    from matplotlib import pyplot as plt

    # Download GW150914 localization
    url = 'https://dcc.ligo.org/public/0122/P1500227/012/bayestar_gstlal_C01.fits.gz'
    m, meta = ligo.skymap.io.read_sky_map(url)

    # Plot sky map on an orthographic projection
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(
        111, projection='astro globe', center='130d -70d')
    ax.imshow_hpx(m, cmap='cylon')

    # Hide the original ('RA', 'Dec') ticks
    for coord in ax.coords:
        coord.set_ticks_visible(False)
        coord.set_ticklabel_visible(False)

    # Construct Hanford-Livingston detector frame at the time of the event
    frame = DetectorFrame(site_1=EarthLocation.of_site('H1'),
                          site_2=EarthLocation.of_site('L1'),
                          obstime=Time(meta['gps_time'], format='gps'))

    # Draw grid for detector frame
    ax.get_coords_overlay(frame).grid()

References
----------
.. [1] LSC/Virgo et al., 2016. "Localization and Broadband Follow-up of the
       Gravitational-wave Transient GW150914." ApJL 826, L13.
       :doi:`10.3847/2041-8205/826/1/L13`

"""  # noqa: E501
from astropy.coordinates import (
    CartesianRepresentation, DynamicMatrixTransform, EarthLocationAttribute,
    frame_transform_graph, ITRS, SphericalRepresentation)
from astropy.coordinates.matrix_utilities import matrix_transpose
from astropy import units as u
import numpy as np

__all__ = ('DetectorFrame',)


class DetectorFrame(ITRS):
    """A coordinate frame to visualize triangulation rings from pairs of
    gravitational-wave detectors.
    """

    # Earth-fixed locations of the two detectors; the frame's z-axis lies
    # along the baseline from site_2 to site_1 (see itrs_to_detectorframe).
    site_1 = EarthLocationAttribute()
    site_2 = EarthLocationAttribute()

    default_representation = SphericalRepresentation


@frame_transform_graph.transform(DynamicMatrixTransform, ITRS, DetectorFrame)
def itrs_to_detectorframe(from_coo, to_frame):
    """Return the rotation matrix from ITRS to the detector frame.

    The z-axis lies along the baseline from site_2 to site_1; the x- and
    y-axes complete a right-handed orthonormal triad.
    """
    e_z = CartesianRepresentation(u.Quantity(to_frame.site_1.geocentric) -
                                  u.Quantity(to_frame.site_2.geocentric))
    e_z /= e_z.norm()
    e_x = CartesianRepresentation(0, 0, 1).cross(e_z)
    e_x /= e_x.norm()
    e_y = e_z.cross(e_x)

    # np.row_stack was removed in NumPy 2.0; np.vstack is the equivalent.
    return np.vstack((e_x.xyz.value,
                      e_y.xyz.value,
                      e_z.xyz.value))


@frame_transform_graph.transform(DynamicMatrixTransform, DetectorFrame, ITRS)
def detectorframe_to_itrs(from_coo, to_frame):
    # The forward matrix is orthonormal, so its inverse is its transpose.
    return matrix_transpose(itrs_to_detectorframe(to_frame, from_coo))

ligo/skymap/coordinates/eigenframe.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
#
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Astropy coordinate frame for eigendecomposition of a cloud of points or a 3D
sky map.
"""

from astropy.coordinates import (
    BaseCoordinateFrame, CartesianRepresentation, DynamicMatrixTransform,
    frame_transform_graph, ICRS, SphericalRepresentation)
from astropy.coordinates import CartesianRepresentationAttribute
from astropy.units import dimensionless_unscaled
import numpy as np

from ..distance import principal_axes

__all__ = ('EigenFrame',)


class EigenFrame(BaseCoordinateFrame):
    """A coordinate frame that has its axes aligned with the principal
    components of a cloud of points.
    """

    # Unit basis vectors of the frame. Defaults are the parent frame's own
    # axes.
    e_x = CartesianRepresentationAttribute(
        default=CartesianRepresentation(1, 0, 0, unit=dimensionless_unscaled),
        unit=dimensionless_unscaled)
    e_y = CartesianRepresentationAttribute(
        default=CartesianRepresentation(0, 1, 0, unit=dimensionless_unscaled),
        unit=dimensionless_unscaled)
    e_z = CartesianRepresentationAttribute(
        default=CartesianRepresentation(0, 0, 1, unit=dimensionless_unscaled),
        unit=dimensionless_unscaled)

    default_representation = SphericalRepresentation

    @classmethod
    def for_coords(cls, coords):
        """Create a coordinate frame that has its axes aligned with the
        principal components of a cloud of points.

        Parameters
        ----------
        coords : `astropy.coordinates.SkyCoord`
            A cloud of points

        Returns
        -------
        frame : `EigenFrame`
            A new coordinate frame

        """
        v = coords.icrs.cartesian.xyz.value
        # Eigendecomposition of the 3x3 second-moment matrix of the points.
        _, r = np.linalg.eigh(np.dot(v, v.T))
        r = r[:, ::-1]  # Order by descending eigenvalue
        e_x, e_y, e_z = CartesianRepresentation(r, unit=dimensionless_unscaled)
        return cls(e_x=e_x, e_y=e_y, e_z=e_z)

    @classmethod
    def for_skymap(cls, prob, distmu, distsigma, nest=False):
        """Create a coordinate frame that has its axes aligned with the
        principal components of a 3D sky map.

        Parameters
        ----------
        prob : `numpy.ndarray`
            Marginal probability (pix^-2)
        distmu : `numpy.ndarray`
            Distance location parameter (Mpc)
        distsigma : `numpy.ndarray`
            Distance scale parameter (Mpc)
        nest : bool, default=False
            Indicates whether the input sky map is in nested rather than
            ring-indexed HEALPix coordinates (default: ring).

        Returns
        -------
        frame : `EigenFrame`
            A new coordinate frame

        """
        r = principal_axes(prob, distmu, distsigma, nest=nest)
        r = r[:, ::-1]  # Order by descending eigenvalue
        e_x, e_y, e_z = CartesianRepresentation(r, unit=dimensionless_unscaled)
        return cls(e_x=e_x, e_y=e_y, e_z=e_z)


@frame_transform_graph.transform(DynamicMatrixTransform, ICRS, EigenFrame)
def icrs_to_eigenframe(from_coo, to_frame):
    """Rotation matrix from ICRS to the eigenframe (basis vectors as rows)."""
    # np.row_stack was removed in NumPy 2.0; np.vstack is the equivalent.
    return np.vstack((to_frame.e_x.xyz.value,
                      to_frame.e_y.xyz.value,
                      to_frame.e_z.xyz.value))


@frame_transform_graph.transform(DynamicMatrixTransform, EigenFrame, ICRS)
def eigenframe_to_icrs(from_coo, to_frame):
    """Rotation matrix from the eigenframe to ICRS: the frame's basis
    vectors arranged as columns (the transpose of the forward matrix).
    """
    basis = (from_coo.e_x.xyz.value,
             from_coo.e_y.xyz.value,
             from_coo.e_z.xyz.value)
    return np.column_stack(basis)

ligo/skymap/io/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

__all__ = ()

# Import all symbols from all submodules of this module.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        # Re-export the submodule and everything in its __all__ at package
        # level. exec is used because the module name is only known at
        # runtime.
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    # Avoid leaking the loop variable into the package namespace.
    del module

# Clean up
del os, pkgutil

ligo/skymap/io/fits.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
#!/usr/bin/env python
#
# Copyright (C) 2013-2022  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Reading and writing HEALPix FITS files.

An example FITS header looks like this:

.. code-block:: sh

    $ fitsheader test.fits.gz
    # HDU 0 in test.fits.gz
    SIMPLE  =                    T / conforms to FITS standard
    BITPIX  =                    8 / array data type
    NAXIS   =                    0 / number of array dimensions
    EXTEND  =                    T

    # HDU 1 in test.fits.gz
    XTENSION= 'BINTABLE'           / binary table extension
    BITPIX  =                    8 / array data type
    NAXIS   =                    2 / number of array dimensions
    NAXIS1  =                 4096 / length of dimension 1
    NAXIS2  =                  192 / length of dimension 2
    PCOUNT  =                    0 / number of group parameters
    GCOUNT  =                    1 / number of groups
    TFIELDS =                    1 / number of table fields
    TTYPE1  = 'PROB    '
    TFORM1  = '1024E   '
    TUNIT1  = 'pix-1   '
    PIXTYPE = 'HEALPIX '           / HEALPIX pixelisation
    ORDERING= 'RING    '           / Pixel ordering scheme, either RING or NESTED
    COORDSYS= 'C       '           / Ecliptic, Galactic or Celestial (equatorial)
    EXTNAME = 'xtension'           / name of this binary table extension
    NSIDE   =                  128 / Resolution parameter of HEALPIX
    FIRSTPIX=                    0 / First pixel # (0 based)
    LASTPIX =               196607 / Last pixel # (0 based)
    INDXSCHM= 'IMPLICIT'           / Indexing: IMPLICIT or EXPLICIT
    OBJECT  = 'FOOBAR 12345'       / Unique identifier for this event
    REFERENC= 'http://www.youtube.com/watch?v=0ccKPSVQcFk' / URL of this event
    DATE-OBS= '2013-04-08T21:37:32.25' / UTC date of the observation
    MJD-OBS =      56391.151064815 / modified Julian date of the observation
    DATE    = '2013-04-08T21:50:32' / UTC date of file creation
    CREATOR = 'fits.py '           / Program that created this file
    RUNTIME =                 21.5 / Runtime in seconds of the CREATOR program
"""  # noqa: E501

import logging
import healpy as hp
import numpy as np
from astropy.io import fits
from astropy.time import Time
from astropy import units as u
from ligo.lw import lsctables
import itertools
import astropy_healpix as ah
from astropy.table import Table
from .. import moc
from ..util.ilwd import ilwd_to_int

# NOTE(review): this is the *root* logger, not logging.getLogger(__name__).
# The module convention would be a named logger; confirm nothing relies on
# root-logger configuration before changing.
log = logging.getLogger()

# Public API of this module.
__all__ = ("read_sky_map", "write_sky_map")


def gps_to_iso8601(gps_time):
    """Convert a floating-point GPS time in seconds to an ISO 8601 date string.

    Parameters
    ----------
    gps_time : float
        Time in seconds since GPS epoch

    Returns
    -------
    iso8601 : str
        ISO 8601 date string (with fractional seconds, microsecond precision)

    Examples
    --------
    >>> gps_to_iso8601(1000000000.01)
    '2011-09-14T01:46:25.010000'
    >>> gps_to_iso8601(1000000000)
    '2011-09-14T01:46:25.000000'
    >>> gps_to_iso8601(1000000000.999999)
    '2011-09-14T01:46:25.999999'
    >>> gps_to_iso8601(1000000000.9999999)
    '2011-09-14T01:46:26.000000'
    >>> gps_to_iso8601(1000000814.999999)
    '2011-09-14T01:59:59.999999'
    >>> gps_to_iso8601(1000000814.9999999)
    '2011-09-14T02:00:00.000000'

    """
    # precision=6 rounds to whole microseconds (see doctests above).
    return Time(float(gps_time), format='gps', precision=6).utc.isot


def iso8601_to_gps(iso8601):
    """Convert an ISO 8601 date string to a floating-point GPS time in seconds.

    Parameters
    ----------
    iso8601 : str
        ISO 8601 date string (with fractional seconds)

    Returns
    -------
    gps : float
        Time in seconds since GPS epoch

    Examples
    --------
    >>> gps_to_iso8601(1129501781.2)
    '2015-10-21T22:29:24.200000'
    >>> iso8601_to_gps('2015-10-21T22:29:24.2')
    1129501781.2

    """
    # Parse as UTC, then express on the GPS time scale.
    parsed = Time(iso8601, scale='utc')
    return parsed.gps


def gps_to_mjd(gps_time):
    """Convert a floating-point GPS time in seconds to a modified Julian day.

    Parameters
    ----------
    gps_time : float
        Time in seconds since GPS epoch

    Returns
    -------
    mjd : float
        Modified Julian day

    Examples
    --------
    >>> '%.9f' % round(gps_to_mjd(1129501781.2), 9)
    '57316.937085648'

    """
    # Build the Time on the GPS scale, then read off the UTC MJD.
    time = Time(gps_time, format='gps')
    return time.utc.mjd


def identity(x):
    """Return *x* unchanged (no-op converter for FITS_META_MAPPING)."""
    return x


def instruments_to_fits(value):
    """Serialize an instruments value for a FITS header card.

    Strings pass through untouched; any other value (e.g. a set of
    detector prefixes) is serialized via ligo.lw's instrumentsproperty.
    """
    if isinstance(value, str):
        return value
    return str(lsctables.instrumentsproperty.set(value))


def instruments_from_fits(value):
    """Deserialize a FITS header instruments card into a set of strings."""
    instruments = lsctables.instrumentsproperty.get(value)
    return set(map(str, instruments))


def metadata_for_version_module(version):
    """Build the ``vcs_version`` metadata entry from a version module."""
    parts = (version.__spec__.parent, version.version)
    return {'vcs_version': ' '.join(parts)}


def normalize_objid(objid):
    try:
        return int(objid)
    except ValueError:
        try:
            return ilwd_to_int(objid)
        except ValueError:
            return str(objid)


# Default column names and units for multi-order (NUNIQ-indexed) sky maps.
DEFAULT_NUNIQ_NAMES = ('PROBDENSITY', 'DISTMU', 'DISTSIGMA', 'DISTNORM')
DEFAULT_NUNIQ_UNITS = (u.steradian**-1, u.Mpc, u.Mpc, u.Mpc**-2)
# Default column names and units for fixed-resolution (NESTED/RING) sky maps.
DEFAULT_NESTED_NAMES = ('PROB', 'DISTMU', 'DISTSIGMA', 'DISTNORM')
DEFAULT_NESTED_UNITS = (u.pix**-1, u.Mpc, u.Mpc, u.Mpc**-2)
# Mapping between Table.meta keys and FITS header cards. Each row is
# (meta key, FITS keyword, FITS comment, to_fits converter, from_fits
# converter). A converter of None disables that direction: e.g. MJD-OBS
# is written from gps_time but never read back, because DATE-OBS already
# round-trips gps_time.
FITS_META_MAPPING = (
    ('objid', 'OBJECT', 'Unique identifier for this event',
     normalize_objid, normalize_objid),
    ('url', 'REFERENC', 'URL of this event', identity, identity),
    ('instruments', 'INSTRUME', 'Instruments that triggered this event',
     instruments_to_fits, instruments_from_fits),
    ('gps_time', 'DATE-OBS', 'UTC date of the observation',
     gps_to_iso8601, iso8601_to_gps),
    ('gps_time', 'MJD-OBS', 'modified Julian date of the observation',
     gps_to_mjd, None),
    ('gps_creation_time', 'DATE', 'UTC date of file creation',
     gps_to_iso8601, iso8601_to_gps),
    ('creator', 'CREATOR', 'Program that created this file',
     identity, identity),
    ('origin', 'ORIGIN', 'Organization responsible for this FITS file',
     identity, identity),
    ('runtime', 'RUNTIME', 'Runtime in seconds of the CREATOR program',
     identity, identity),
    ('distmean', 'DISTMEAN', 'Posterior mean distance (Mpc)',
     identity, identity),
    ('diststd', 'DISTSTD', 'Posterior standard deviation of distance (Mpc)',
     identity, identity),
    ('log_bci', 'LOGBCI', 'Log Bayes factor: coherent vs. incoherent',
     identity, identity),
    ('log_bsn', 'LOGBSN', 'Log Bayes factor: signal vs. noise',
     identity, identity),
    ('vcs_version', 'VCSVERS', 'Software version',
     identity, identity),
    ('vcs_revision', 'VCSREV', 'Software revision (Git)',
     identity, identity),
    ('build_date', 'DATE-BLD', 'Software build date',
     identity, identity))


def write_sky_map(filename, m, **kwargs):
    """Write a gravitational-wave sky map to a file, populating the header
    with optional metadata.

    Parameters
    ----------
    filename: str
        Path to the optionally gzip-compressed FITS file.

    m : `astropy.table.Table`, `numpy.array`
        If a Numpy record array or astropy.table.Table instance, and has a
        column named 'UNIQ', then interpret the input as NUNIQ-style
        multi-order map [1]_. Otherwise, interpret as a NESTED or RING ordered
        map.

    **kwargs
        Additional metadata to add to FITS header. If m is an
        `astropy.table.Table` instance, then the header is initialized from
        both `m.meta` and `kwargs`.

    References
    ----------
    .. [1] Górski, K.M., Wandelt, B.D., Hivon, E., Hansen, F.K., & Banday, A.J.
        2017. The HEALPix Primer. The Unique Identifier scheme.
        http://healpix.sourceforge.net/html/intronode4.htm#SECTION00042000000000000000

    Examples
    --------
    Test header contents:

    >>> order = 9
    >>> nside = 2 ** order
    >>> npix = ah.nside_to_npix(nside)
    >>> prob = np.ones(npix, dtype=float) / npix

    >>> import tempfile
    >>> from ligo.skymap import version
    >>> with tempfile.NamedTemporaryFile(suffix='.fits') as f:
    ...     write_sky_map(f.name, prob, nest=True,
    ...                   vcs_version='foo 1.0', vcs_revision='bar',
    ...                   build_date='2018-01-01T00:00:00')
    ...     for card in fits.getheader(f.name, 1).cards:
    ...         print(str(card).rstrip())
    XTENSION= 'BINTABLE'           / binary table extension
    BITPIX  =                    8 / array data type
    NAXIS   =                    2 / number of array dimensions
    NAXIS1  =                    8 / length of dimension 1
    NAXIS2  =              3145728 / length of dimension 2
    PCOUNT  =                    0 / number of group parameters
    GCOUNT  =                    1 / number of groups
    TFIELDS =                    1 / number of table fields
    TTYPE1  = 'PROB    '
    TFORM1  = 'D       '
    TUNIT1  = 'pix-1   '
    PIXTYPE = 'HEALPIX '           / HEALPIX pixelisation
    ORDERING= 'NESTED  '           / Pixel ordering scheme: RING, NESTED, or NUNIQ
    COORDSYS= 'C       '           / Ecliptic, Galactic or Celestial (equatorial)
    NSIDE   =                  512 / Resolution parameter of HEALPIX
    INDXSCHM= 'IMPLICIT'           / Indexing: IMPLICIT or EXPLICIT
    VCSVERS = 'foo 1.0 '           / Software version
    VCSREV  = 'bar     '           / Software revision (Git)
    DATE-BLD= '2018-01-01T00:00:00' / Software build date

    >>> uniq = moc.nest2uniq(np.uint8(order), np.arange(npix))
    >>> probdensity = prob / hp.nside2pixarea(nside)
    >>> moc_data = np.rec.fromarrays(
    ...     [uniq, probdensity], names=['UNIQ', 'PROBDENSITY'])
    >>> with tempfile.NamedTemporaryFile(suffix='.fits') as f:
    ...     write_sky_map(f.name, moc_data,
    ...                   vcs_version='foo 1.0', vcs_revision='bar',
    ...                   build_date='2018-01-01T00:00:00')
    ...     for card in fits.getheader(f.name, 1).cards:
    ...         print(str(card).rstrip())
    XTENSION= 'BINTABLE'           / binary table extension
    BITPIX  =                    8 / array data type
    NAXIS   =                    2 / number of array dimensions
    NAXIS1  =                   16 / length of dimension 1
    NAXIS2  =              3145728 / length of dimension 2
    PCOUNT  =                    0 / number of group parameters
    GCOUNT  =                    1 / number of groups
    TFIELDS =                    2 / number of table fields
    TTYPE1  = 'UNIQ    '
    TFORM1  = 'K       '
    TTYPE2  = 'PROBDENSITY'
    TFORM2  = 'D       '
    TUNIT2  = 'sr-1    '
    PIXTYPE = 'HEALPIX '           / HEALPIX pixelisation
    ORDERING= 'NUNIQ   '           / Pixel ordering scheme: RING, NESTED, or NUNIQ
    COORDSYS= 'C       '           / Ecliptic, Galactic or Celestial (equatorial)
    MOCORDER=                    9 / MOC resolution (best order)
    INDXSCHM= 'EXPLICIT'           / Indexing: IMPLICIT or EXPLICIT
    VCSVERS = 'foo 1.0 '           / Software version
    VCSREV  = 'bar     '           / Software revision (Git)
    DATE-BLD= '2018-01-01T00:00:00' / Software build date

    """  # noqa: E501
    log.debug('normalizing metadata')
    # Coerce the input to an astropy Table. A bare 1-D array is wrapped as a
    # single column; unnamed columns get the default nested names in order.
    if isinstance(m, Table) or (isinstance(m, np.ndarray) and m.dtype.names):
        m = Table(m, copy=False)
    else:
        if np.ndim(m) == 1:
            m = [m]
        m = Table(m, names=DEFAULT_NESTED_NAMES[:len(m)], copy=False)
    m.meta.update(kwargs)

    if 'UNIQ' in m.colnames:
        # Multi-order (NUNIQ) map: pixel indexing is EXPLICIT via UNIQ.
        default_names = DEFAULT_NUNIQ_NAMES
        default_units = DEFAULT_NUNIQ_UNITS
        extra_header = [
            ('PIXTYPE', 'HEALPIX',
             'HEALPIX pixelisation'),
            ('ORDERING', 'NUNIQ',
             'Pixel ordering scheme: RING, NESTED, or NUNIQ'),
            ('COORDSYS', 'C',
             'Ecliptic, Galactic or Celestial (equatorial)'),
            ('MOCORDER', moc.uniq2order(m['UNIQ'].max()),
             'MOC resolution (best order)'),
            ('INDXSCHM', 'EXPLICIT',
             'Indexing: IMPLICIT or EXPLICIT')]
        # Ignore nest keyword argument if present
        m.meta.pop('nest', False)
    else:
        # Fixed-resolution map: ordering is taken from the 'nest' meta key.
        default_names = DEFAULT_NESTED_NAMES
        default_units = DEFAULT_NESTED_UNITS
        ordering = 'NESTED' if m.meta.pop('nest', False) else 'RING'
        extra_header = [
            ('PIXTYPE', 'HEALPIX',
             'HEALPIX pixelisation'),
            ('ORDERING', ordering,
             'Pixel ordering scheme: RING, NESTED, or NUNIQ'),
            ('COORDSYS', 'C',
             'Ecliptic, Galactic or Celestial (equatorial)'),
            ('NSIDE', ah.npix_to_nside(len(m)),
             'Resolution parameter of HEALPIX'),
            ('INDXSCHM', 'IMPLICIT',
             'Indexing: IMPLICIT or EXPLICIT')]

    # Translate recognized metadata keys into FITS header cards; a single
    # key (e.g. gps_time) may map to several cards.
    for key, rows in itertools.groupby(FITS_META_MAPPING, lambda row: row[0]):
        try:
            value = m.meta.pop(key)
        except KeyError:
            pass
        else:
            for row in rows:
                _, fits_key, fits_comment, to_fits, _ = row
                if to_fits is not None:
                    extra_header.append(
                        (fits_key, to_fits(value), fits_comment))

    # Assign default units to known columns that do not already have one.
    for default_name, default_unit in zip(default_names, default_units):
        try:
            col = m[default_name]
        except KeyError:
            pass
        else:
            if not col.unit:
                col.unit = default_unit

    log.debug('converting from Astropy table to FITS HDU list')
    hdu = fits.table_to_hdu(m)
    hdu.header.extend(extra_header)
    hdulist = fits.HDUList([fits.PrimaryHDU(), hdu])
    log.debug('saving')
    hdulist.writeto(filename, overwrite=True)


def read_sky_map(filename, nest=False, distances=False, moc=False, **kwargs):
    """Read a LIGO/Virgo/KAGRA-type sky map and return a tuple of the HEALPix
    array and a dictionary of metadata from the header.

    Parameters
    ----------
    filename: string
        Path to the optionally gzip-compressed FITS file.

    nest: bool, optional
        If omitted or False, then detect the pixel ordering in the FITS file
        and rearrange if necessary to RING indexing before returning.

        If True, then detect the pixel ordering and rearrange if necessary to
        NESTED indexing before returning.

        If None, then preserve the ordering from the FITS file.

        Regardless of the value of this option, the ordering used in the FITS
        file is indicated as the value of the 'nest' key in the metadata
        dictionary.

    distances: bool, optional
        If true, then also read the additional HEALPix layers representing
        the conditional mean and standard deviation of distance as a function
        of sky location.

    moc: bool, optional
        If true, then preserve multi-order structure if present.

    Examples
    --------
    Test that we can read a legacy IDL-compatible file
    (https://bugs.ligo.org/redmine/issues/5168):

    >>> import tempfile
    >>> with tempfile.NamedTemporaryFile(suffix='.fits') as f:
    ...     nside = 512
    ...     npix = ah.nside_to_npix(nside)
    ...     ipix_nest = np.arange(npix)
    ...     hp.write_map(f.name, ipix_nest, nest=True, column_names=['PROB'])
    ...     m, meta = read_sky_map(f.name)
    ...     np.testing.assert_array_equal(m, hp.ring2nest(nside, ipix_nest))

    """
    # NOTE: the 'moc' argument shadows the ligo.skymap.moc module inside
    # this function; the module is not needed here.
    m = Table.read(filename, format='fits', **kwargs)

    # Remove some keys that we do not need
    for key in (
            'PIXTYPE', 'EXTNAME', 'NSIDE', 'FIRSTPIX', 'LASTPIX', 'INDXSCHM',
            'MOCORDER'):
        m.meta.pop(key, None)

    if m.meta.pop('COORDSYS', 'C') != 'C':
        raise ValueError('ligo.skymap only reads and writes sky maps in '
                         'equatorial coordinates.')

    # Convert the ORDERING card to the boolean 'nest' metadata key.
    # NUNIQ maps carry their ordering in the UNIQ column, so no key is set.
    try:
        value = m.meta.pop('ORDERING')
    except KeyError:
        pass
    else:
        if value == 'RING':
            m.meta['nest'] = False
        elif value == 'NESTED':
            m.meta['nest'] = True
        elif value == 'NUNIQ':
            pass
        else:
            raise ValueError(
                'ORDERING card in header has unknown value: {0}'.format(value))

    # Translate recognized FITS header cards back to metadata keys.
    for fits_key, rows in itertools.groupby(
            FITS_META_MAPPING, lambda row: row[1]):
        try:
            value = m.meta.pop(fits_key)
        except KeyError:
            pass
        else:
            for row in rows:
                key, _, _, _, from_fits = row
                if from_fits is not None:
                    m.meta[key] = from_fits(value)

    # FIXME: Fermi GBM HEALPix maps use the column name 'PROBABILITY',
    # instead of the LIGO/Virgo/KAGRA convention of 'PROB'.
    #
    # Fermi may change to our convention in the future, but for now we
    # rename the column.
    if 'PROBABILITY' in m.colnames:
        m.rename_column('PROBABILITY', 'PROB')

    # For a long time, we produced files with a UNIQ column that was an
    # unsigned integer. Cast it here to a signed integer so that the user
    # can handle old or new sky maps the same way.
    if 'UNIQ' in m.colnames:
        m['UNIQ'] = m['UNIQ'].astype(np.int64)

    # Flatten any multi-dimensional columns (e.g. 1024-wide rows written
    # by healpy) into plain 1-D columns.
    if 'UNIQ' not in m.colnames:
        m = Table([col.ravel() for col in m.columns.values()], meta=m.meta)

    if 'UNIQ' in m.colnames and not moc:
        # Flatten a multi-order map to a fixed-resolution NESTED map.
        from ..bayestar import rasterize
        m = rasterize(m)
        m.meta['nest'] = True
    elif 'UNIQ' not in m.colnames and moc:
        # Promote a fixed-resolution map to multi-order structure.
        from ..bayestar import derasterize
        if not m.meta['nest']:
            # Reindex from RING to NESTED ordering before derasterizing.
            npix = len(m)
            nside = ah.npix_to_nside(npix)
            m = m[hp.nest2ring(nside, np.arange(npix))]
        m = derasterize(m)
        m.meta.pop('nest', None)

    # Reorder a fixed-resolution map to the ordering the caller requested.
    if 'UNIQ' not in m.colnames:
        npix = len(m)
        nside = ah.npix_to_nside(npix)

        if nest is None:
            pass
        elif m.meta['nest'] and not nest:
            m = m[hp.ring2nest(nside, np.arange(npix))]
        elif not m.meta['nest'] and nest:
            m = m[hp.nest2ring(nside, np.arange(npix))]

    if moc:
        return m
    elif distances:
        return tuple(
            np.asarray(m[name]) for name in DEFAULT_NESTED_NAMES), m.meta
    else:
        return np.asarray(m[DEFAULT_NESTED_NAMES[0]]), m.meta


if __name__ == '__main__':
    # Demo: write a random normalized sky map and read it back.
    import os
    nside = 128
    npix = ah.nside_to_npix(nside)
    prob = np.random.random(npix)
    # Use the vectorized NumPy sum: builtin sum() iterates the array
    # element-by-element at Python speed.
    prob /= prob.sum()

    write_sky_map(
        'test.fits.gz', prob,
        objid='FOOBAR 12345',
        gps_time=1049492268.25,
        creator=os.path.basename(__file__),
        url='http://www.youtube.com/watch?v=0ccKPSVQcFk',
        origin='LIGO Scientific Collaboration',
        runtime=21.5)

    print(read_sky_map('test.fits.gz'))

ligo/skymap/io/hdf5.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
# Copyright (C) 2016-2022  Leo Singer, John Veitch
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read HDF5 posterior sample chain HDF5 files."""

import numpy as np
import h5py
from astropy.table import Column, Table

# Constants from lalinference module
POSTERIOR_SAMPLES = 'posterior_samples'  # default table name to search for
# Column 'vary' type codes stored in each column's metadata.
LINEAR = 0
CIRCULAR = 1
FIXED = 2   # fixed columns are collapsed to scalar table attributes on write
OUTPUT = 3  # default vary type assigned on read when none is recorded

# Public API of this module.
__all__ = ('read_samples', 'write_samples')


def _identity(x):
    return x


# Mapping from posterior-sample column names used by various pipelines to
# the canonical short names used by this package. Each row is
# (source name, canonical name, transform); the transform is applied to
# the column values (e.g. np.exp to turn a log-distance into a distance).
_colname_map = (('rightascension', 'ra', _identity),
                ('right_ascension', 'ra', _identity),
                ('declination', 'dec', _identity),
                ('logdistance', 'dist', np.exp),
                ('distance', 'dist', _identity),
                ('luminosity_distance', 'dist', _identity),
                ('polarisation', 'psi', _identity),
                ('chirpmass', 'mc', _identity),
                ('chirp_mass', 'mc', _identity),
                ('a_spin1', 'a1', _identity),
                ('a_1', 'a1', _identity),
                ('a_spin2', 'a2', _identity),
                ('a_2', 'a2', _identity),
                ('tilt_spin1', 'tilt1', _identity),
                ('tilt_1', 'tilt1', _identity),
                ('tilt_spin2', 'tilt2', _identity),
                ('tilt_2', 'tilt2', _identity),
                ('geocent_time', 'time', _identity))


def _remap_colnames(table):
    """Rename known posterior-sample columns to their canonical names,
    applying the associated transform (see _colname_map)."""
    for src_name, dest_name, transform in _colname_map:
        if src_name not in table.colnames:
            continue
        table[dest_name] = transform(table.columns.pop(src_name))


def _find_table(group, tablename):
    """Recursively search an HDF5 group or file for a dataset by name.

    Parameters
    ----------
    group : `h5py.File` or `h5py.Group`
        The file or group to search
    tablename : str
        The name of the table to search for

    Returns
    -------
    dataset : `h5py.Dataset`
        The dataset whose name is `tablename`

    Raises
    ------
    KeyError
        If the table is not found or if multiple matching tables are found

    Examples
    --------
    Check that we can find a file by name:

    >>> import os.path
    >>> import tempfile
    >>> table = Table(np.eye(3), names=['a', 'b', 'c'])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'bar')
    <HDF5 dataset "bar": shape (3,), type "|V24">

    Check that an exception is raised if the table is not found:

    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'plugh')
    Traceback (most recent call last):
        ...
    KeyError: 'Table not found: plugh'

    Check that an exception is raised if multiple tables are found:

    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'bat')
    Traceback (most recent call last):
        ...
    KeyError: 'Multiple tables called bat exist: foo/bat, foo/xyzzy/bat'

    """
    results = {}

    def visitor(key, value):
        _, _, name = key.rpartition('/')
        if name == tablename:
            results[key] = value

    group.visititems(visitor)

    if len(results) == 0:
        raise KeyError('Table not found: {0}'.format(tablename))

    if len(results) > 1:
        raise KeyError('Multiple tables called {0} exist: {1}'.format(
            tablename, ', '.join(sorted(results.keys()))))

    table, = results.values()
    return table


def read_samples(filename, path=None, tablename=POSTERIOR_SAMPLES):
    """Read an HDF5 sample chain file.

    Parameters
    ----------
    filename : str
        The path of the HDF5 file on the filesystem.
    path : str, optional
        The path of the dataset within the HDF5 file.
    tablename : str, optional
        The name of table to search for recursively within the HDF5 file.
        By default, search for 'posterior_samples'.

    Returns
    -------
    chain : `astropy.table.Table`
        The sample chain as an Astropy table.

    Examples
    --------
    Test reading a file written using the Python API:

    >>> import os.path
    >>> import tempfile
    >>> table = Table([
    ...     Column(np.ones(10), name='foo', meta={'vary': FIXED}),
    ...     Column(np.arange(10), name='bar', meta={'vary': LINEAR}),
    ...     Column(np.arange(10) * np.pi, name='bat', meta={'vary': CIRCULAR}),
    ...     Column(np.arange(10), name='baz', meta={'vary': OUTPUT})
    ... ])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     write_samples(table, filename, path='foo/bar/posterior_samples')
    ...     len(read_samples(filename))
    10

    Test reading a file that was written using the LAL HDF5 C API:

    >>> from importlib.resources import files
    >>> with files('ligo.skymap.io.tests.data').joinpath(
    ...         'test.hdf5').open('rb') as f:
    ...     table = read_samples(f)
    >>> table.colnames
    ['uvw', 'opq', 'lmn', 'ijk', 'def', 'abc', 'ghi', 'rst']

    """
    with h5py.File(filename, 'r') as f:
        if path is not None:  # Look for a given path
            table = f[path]
        else:  # Look for a given table name
            table = _find_table(f, tablename)
        # Materialize into an astropy Table before the file is closed.
        table = Table.read(table)

    # Restore vary types.
    # write_samples records them as FIELD_<i>_VARY attributes; anything
    # without a recorded type defaults to OUTPUT.
    for i, column in enumerate(table.columns.values()):
        column.meta['vary'] = table.meta.get(
            'FIELD_{0}_VARY'.format(i), OUTPUT)

    # Restore fixed columns from table attributes.
    # Each remaining scalar attribute is broadcast back into a full-length
    # FIXED column.
    for key, value in table.meta.items():
        # Skip attributes from H5TB interface
        # (https://www.hdfgroup.org/HDF5/doc/HL/H5TB_Spec.html).
        if key == 'CLASS' or key == 'VERSION' or key == 'TITLE' or \
                key.startswith('FIELD_'):
            continue
        table.add_column(Column([value] * len(table), name=key,
                         meta={'vary': FIXED}))

    # Delete remaining table attributes.
    table.meta.clear()

    # Normalize column names.
    _remap_colnames(table)

    # Done!
    return table


def write_samples(table, filename, metadata=None, **kwargs):
    """Write an HDF5 sample chain file.

    Parameters
    ----------
    table : `astropy.table.Table`
        The sample chain as an Astropy table.
    filename : str
        The path of the HDF5 file on the filesystem.
    metadata: dict (optional)
        Dictionary of (path, value) pairs of metadata attributes
        to add to the output file
    kwargs: dict
        Any keyword arguments for `astropy.table.Table.write`.

    Examples
    --------
    Check that we catch columns that are supposed to be FIXED but are not:

    >>> table = Table([
    ...     Column(np.arange(10), name='foo', meta={'vary': FIXED})
    ... ])
    >>> write_samples(table, 'bar.hdf5', 'bat/baz')
    Traceback (most recent call last):
        ...
    AssertionError: 
    Arrays are not equal
    Column foo is a fixed column, but its values are not identical
    ...

    And now try writing an arbitrary example to a temporary file.

    >>> import os.path
    >>> import tempfile
    >>> table = Table([
    ...     Column(np.ones(10), name='foo', meta={'vary': FIXED}),
    ...     Column(np.arange(10), name='bar', meta={'vary': LINEAR}),
    ...     Column(np.arange(10) * np.pi, name='bat', meta={'vary': CIRCULAR}),
    ...     Column(np.arange(10), name='baz', meta={'vary': OUTPUT}),
    ...     Column(np.ones(10), name='plugh'),
    ...     Column(np.arange(10), name='xyzzy')
    ... ])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     write_samples(
    ...         table, os.path.join(dir, 'test.hdf5'), path='bat/baz',
    ...         metadata={'bat/baz': {'widget': 'shoephone'}})

    """  # noqa: W291
    # Copy the table so that we do not modify the original.
    table = table.copy()

    # Make sure that all tables have a 'vary' type.
    # Columns whose values are all identical are classified as FIXED,
    # everything else as OUTPUT.
    for column in table.columns.values():
        if 'vary' not in column.meta:
            if np.all(column[0] == column[1:]):
                column.meta['vary'] = FIXED
            else:
                column.meta['vary'] = OUTPUT
    # Reconstruct table attributes.
    # FIXED columns are collapsed to a single scalar stored in table.meta
    # (and hence written as an HDF5 attribute); the column itself is dropped.
    for colname, column in tuple(table.columns.items()):
        if column.meta['vary'] == FIXED:
            np.testing.assert_array_equal(column[1:], column[0],
                                          'Column {0} is a fixed column, but '
                                          'its values are not identical'
                                          .format(column.name))
            table.meta[colname] = column[0]
            del table[colname]
    # Record each remaining column's vary type as a FIELD_<i>_VARY attribute
    # (read back by read_samples).
    for i, column in enumerate(table.columns.values()):
        table.meta['FIELD_{0}_VARY'.format(i)] = column.meta.pop('vary')
    table.write(filename, format='hdf5', **kwargs)
    # Optionally attach extra attributes to arbitrary paths in the file.
    if metadata:
        with h5py.File(filename, 'r+') as hdf:
            for internal_path, attributes in metadata.items():
                for key, value in attributes.items():
                    try:
                        hdf[internal_path].attrs[key] = value
                    except KeyError:
                        raise KeyError(
                            'Unable to set metadata {0}[{1}] = {2}'.format(
                                internal_path, key, value))

ligo/skymap/io/events/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

# Populated below as each submodule is imported.
__all__ = ()

# Import all symbols from all submodules of this module.
# ``exec`` is needed because the submodule name is only known at run time
# and ``from .<name> import *`` cannot otherwise be written dynamically.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    # Drop the loop variable each iteration so it never leaks from the loop.
    del module

# Clean up
del os, pkgutil

ligo/skymap/io/events/base.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Base classes for reading events from search pipelines."""

from abc import ABCMeta, abstractmethod
from collections.abc import Mapping

__all__ = ('EventSource', 'Event', 'SingleEvent')


def _fmt(obj, keys):
    kvs = ', '.join('{}={!r}'.format(key, getattr(obj, key)) for key in keys)
    return '<{}({})>'.format(obj.__class__.__name__, kvs)


class EventSource(Mapping):
    """Abstraction of a source of coincident events.

    This is a mapping from event IDs (which may be any hashable type, but are
    generally integers or strings) to instances of `Event`.
    """

    def __str__(self):
        # Some sources cannot report their length (e.g. remote services);
        # degrade gracefully to '...' instead of raising.
        try:
            length = len(self)
        except (NotImplementedError, TypeError):
            contents = '...'
        else:
            contents = '...{} items...'.format(length)
        return '<{}({{{}}})>'.format(self.__class__.__name__, contents)

    def __repr__(self):
        # Match the error handling of __str__: previously only
        # NotImplementedError was caught here, so a TypeError from len()
        # would propagate out of repr() instead of producing the '...'
        # placeholder.
        try:
            len(self)
        except (NotImplementedError, TypeError):
            contents = '...'
        else:
            contents = ', '.join('{}: {!r}'.format(key, value)
                                 for key, value in self.items())
        return '{}({{{}}})'.format(self.__class__.__name__, contents)


class Event(metaclass=ABCMeta):
    """Abstraction of a coincident trigger.

    Attributes
    ----------
    singles : list, tuple
        Sequence of `SingleEvent`
    template_args : dict
        Dictionary of template parameters

    """

    @property
    @abstractmethod
    def singles(self):
        """Sequence of `SingleEvent` instances for this coincidence."""
        raise NotImplementedError

    @property
    @abstractmethod
    def template_args(self):
        """Dictionary of template waveform parameters."""
        raise NotImplementedError

    # Attributes displayed by __str__ / __repr__.
    __str_keys = ('singles',)

    def __str__(self):
        return _fmt(self, self.__str_keys)

    __repr__ = __str__


class SingleEvent(metaclass=ABCMeta):
    """Abstraction of a single-detector trigger.

    Attributes
    ----------
    detector : str
        Instrument name (e.g. 'H1')
    snr : float
        Signal to noise ratio
    phase : float
        Phase on arrival
    time : float
        GPS time on arrival
    zerolag_time : float
        GPS time on arrival in zero-lag data, without time slides applied
    psd : `REAL8FrequencySeries`
        Power spectral density
    snr_series : `COMPLEX8TimeSeries`
        SNR time series

    """

    @property
    @abstractmethod
    def detector(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def snr(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def phase(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def time(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def zerolag_time(self):
        raise NotImplementedError

    @property
    @abstractmethod
    def psd(self):
        raise NotImplementedError

    @property
    def snr_series(self):
        # Optional: subclasses that carry an SNR time series override this.
        return None

    # Attributes displayed by __str__ / __repr__.
    __str_keys = ('detector', 'snr', 'phase', 'time')

    def __str__(self):
        keys = list(self.__str_keys)
        if self.time != self.zerolag_time:
            # Only display the zero-lag time when a time slide was applied.
            keys.append('zerolag_time')
        return _fmt(self, keys)

    __repr__ = __str__

ligo/skymap/io/events/detector_disabled.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Modify events by artificially disabling specified detectors."""
from .base import Event, EventSource

__all__ = ('DetectorDisabledEventSource', 'DetectorDisabledError')


class DetectorDisabledError(ValueError):
    """Raised when disabling detectors would either have no effect on an
    event or would exclude all of the event's data."""
    pass


class DetectorDisabledEventSource(EventSource):
    """View of an event source in which some detectors are treated as
    disabled.

    Parameters
    ----------
    base_source : `EventSource`
        The underlying event source to wrap.
    disabled_detectors : iterable of str
        Names of the detectors to disable.
    raises : bool
        If True (default), accessing an event's singles raises
        `DetectorDisabledError` when disabling the detectors would have no
        effect on, or would exclude all data for, that event.
    """

    def __init__(self, base_source, disabled_detectors, raises=True):
        self.raises = raises
        self.disabled_detectors = set(disabled_detectors)
        self.base_source = base_source

    def __len__(self):
        return len(self.base_source)

    def __iter__(self):
        return iter(self.base_source)

    def __getitem__(self, key):
        return DetectorDisabledEvent(self, self.base_source[key])


class DetectorDisabledEvent(Event):
    """An `Event` whose singles exclude the disabled detectors."""

    def __init__(self, source, base_event):
        self.source = source
        self.base_event = base_event

    @property
    def singles(self):
        """Single-detector triggers with the disabled detectors removed."""
        disabled = self.source.disabled_detectors
        singles = self.base_event.singles
        if self.source.raises:
            detectors = {single.detector for single in singles}
            # No overlap: disabling these detectors changes nothing.
            if detectors.isdisjoint(disabled):
                raise DetectorDisabledError(
                    'Disabling detectors {{{}}} would have no effect on this '
                    'event with detectors {{{}}}'.format(
                        ' '.join(sorted(disabled)),
                        ' '.join(sorted(detectors))))
            # Full overlap: nothing would be left after disabling.
            if detectors <= disabled:
                raise DetectorDisabledError(
                    'Disabling detectors {{{}}} would exclude all data for '
                    'this event with detectors {{{}}}'.format(
                        ' '.join(sorted(disabled)),
                        ' '.join(sorted(detectors))))
        return tuple(single for single in singles
                     if single.detector not in disabled)

    @property
    def template_args(self):
        """Template parameters, delegated unchanged to the wrapped event."""
        return self.base_event.template_args


# Conventional module-level entry point: each events submodule exposes its
# event source class as ``open`` (shadowing the builtin within this module).
open = DetectorDisabledEventSource

ligo/skymap/io/events/gracedb.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
# Copyright (C) 2017-2022  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import logging

from ligo.gracedb import rest
from ligo.lw import ligolw

from .base import EventSource
from .ligolw import LigoLWEventSource, _read_xml

__all__ = ('GraceDBEventSource',)

log = logging.getLogger('BAYESTAR')


def _has_psds(xmldoc):
    """Return True if *xmldoc* embeds at least one REAL8FrequencySeries."""
    elements = xmldoc.getElementsByTagName(ligolw.LIGO_LW.tagName)
    return any(elem.hasAttribute('Name') and
               elem.Name == 'REAL8FrequencySeries'
               for elem in elements)


class GraceDBEventSource(EventSource):
    """Read events from GraceDB.

    Parameters
    ----------
    graceids : list
        List of GraceDB ID strings.
    client : `ligo.gracedb.rest.GraceDb`, optional
        Client object; a default `GraceDb` instance is created when omitted.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    """

    def __init__(self, graceids, client=None):
        self._graceids = graceids
        self._client = rest.GraceDb() if client is None else client

    def __len__(self):
        return len(self._graceids)

    def __iter__(self):
        return iter(self._graceids)

    def __getitem__(self, graceid):
        # Prefer the PSD embedded in coinc.xml; fall back to downloading
        # psd.xml.gz when it is absent.
        coinc_file, _ = _read_xml(self._client.files(graceid, 'coinc.xml'))
        if _has_psds(coinc_file):
            psd_file = coinc_file
        else:
            log.warning('The coinc.xml should contain a PSD, but it does not. '
                        'Attempting to download psd.xml.gz.')
            psd_file = self._client.files(graceid, 'psd.xml.gz')
        source = LigoLWEventSource(coinc_file, psd_file=psd_file,
                                   coinc_def=None)
        event, = source.values()
        return event


# Conventional module-level entry point: each events submodule exposes its
# event source class as ``open`` (shadowing the builtin within this module).
open = GraceDBEventSource

ligo/skymap/io/events/hdf.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read events from PyCBC-style HDF5 output."""
from operator import itemgetter
from itertools import groupby

import h5py
import numpy as np
import lal
from ligo.segments import segment, segmentlist

from .base import Event, EventSource, SingleEvent

__all__ = ('HDFEventSource',)


class _psd_segment(segment):  # noqa: N801
    """A `ligo.segments.segment` that also carries the PSD dataset which
    covers that time span."""

    def __new__(cls, psd, *args):
        # segment is an immutable tuple subclass, so the extra psd argument
        # must be stripped off before delegating construction.
        return segment.__new__(cls, *args)

    def __init__(self, psd, *args):
        # Attach the PSD dataset to the segment.
        self.psd = psd


def _hdf_file(f):
    """Coerce *f* (an `h5py.File`, a named file-like object, or a path) to
    an open, read-only `h5py.File`."""
    if isinstance(f, h5py.File):
        return f
    if hasattr(f, 'read') and hasattr(f, 'name'):
        # Named file object: reopen by name through h5py.
        return h5py.File(f.name, 'r')
    return h5py.File(f, 'r')


def _classify_hdf_file(f, sample):
    """Classify an open PyCBC HDF5 file by its layout.

    Parameters
    ----------
    f : `h5py.File`
        The open HDF5 file.
    sample : str
        Name of the coinc sample group to look for (e.g. ``'foreground'``).

    Returns
    -------
    str
        One of ``'coincs'``, ``'psds'``, ``'triggers'``, or ``'bank'``.

    Raises
    ------
    ValueError
        If the file matches none of the known PyCBC layouts.

    """
    if sample in f:
        return 'coincs'
    # The group names are irrelevant to classification, so iterate over the
    # values only (the previous version iterated items() and ignored keys).
    for value in f.values():
        if isinstance(value, h5py.Group):
            if 'psds' in value:
                return 'psds'
            if 'snr' in value and 'coa_phase' in value and 'end_time' in value:
                return 'triggers'
    if 'parameters' in f.attrs:
        return 'bank'
    raise ValueError('Unrecognized PyCBC file type')


class HDFEventSource(EventSource):
    """Read events from PyCBC-style HDF5 files.

    Parameters
    ----------
    *files : list of str, file-like object, or `h5py.File` objects
        The PyCBC coinc, bank, psds, and triggers files, in any order.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    """

    def __init__(self, *files, **kwargs):
        # Name of the sample group to read from the coinc file
        # (default: 'foreground').
        sample = kwargs.get('sample', 'foreground')

        # Open the files and split them into coinc files, bank files, psds,
        # and triggers.
        key = itemgetter(0)
        files = [_hdf_file(f) for f in files]
        # groupby below requires its input sorted by the same key.
        files = sorted(
            [(_classify_hdf_file(f, sample), f) for f in files], key=key)
        files = {key: list(v[1] for v in value)
                 for key, value in groupby(files, key)}

        # Exactly one coinc file and one bank file are required; one PSD
        # file and one trigger file are expected per detector.
        try:
            coinc_file, = files['coincs']
        except (KeyError, ValueError):
            raise ValueError('You must provide exactly one coinc file.')
        try:
            bank_file, = files['bank']
        except (KeyError, ValueError):
            raise ValueError(
                'You must provide exactly one template bank file.')
        try:
            psd_files = files['psds']
        except KeyError:
            raise ValueError('You must provide PSD files.')
        try:
            trigger_files = files['triggers']
        except KeyError:
            raise ValueError('You must provide trigger files.')

        self._bank = bank_file

        # The coinc file names instruments in attributes 'detector_0',
        # 'detector_1', ...; recover (number, name) pairs sorted by number.
        key_prefix = 'detector_'
        detector_nums, self._ifos = zip(*sorted(
            (int(key[len(key_prefix):]), value)
            for key, value in coinc_file.attrs.items()
            if key.startswith(key_prefix)))

        coinc_group = coinc_file[sample]
        self._timeslide_interval = coinc_file.attrs.get(
            'timeslide_interval', 0)
        self._template_ids = coinc_group['template_id']
        # NOTE: len(self) is safe here because __len__ only needs
        # self._template_ids, which was assigned on the previous line.
        self._timeslide_ids = coinc_group.get(
            'timeslide_id', np.zeros(len(self)))
        self._trigger_ids = [
            coinc_group['trigger_id{}'.format(detector_num)]
            for detector_num in detector_nums]

        # Collect, per detector, the snr, coa_phase, and end_time columns.
        triggers = {}
        for f in trigger_files:
            (ifo, group), = f.items()
            triggers[ifo] = [
                group['snr'], group['coa_phase'], group['end_time']]
        self._triggers = tuple(triggers[ifo] for ifo in self._ifos)

        # Build, per detector, a segment list whose segments carry the PSD
        # dataset covering them (see _psd_segment above).
        psdseglistdict = {}
        for psd_file in psd_files:
            (ifo, group), = psd_file.items()
            psd = [group['psds'][str(i)] for i in range(len(group['psds']))]
            psdseglistdict[ifo] = segmentlist(
                _psd_segment(*segargs) for segargs in zip(
                    psd, group['start_time'], group['end_time']))
        self._psds = [psdseglistdict[ifo] for ifo in self._ifos]

    def __getitem__(self, id):
        # Events are materialized lazily; they index back into the datasets
        # held by this source.
        return HDFEvent(self, id)

    def __iter__(self):
        return iter(range(len(self)))

    def __len__(self):
        return len(self._template_ids)


class HDFEvent(Event):
    """A coincident event drawn from a PyCBC HDF5 coinc file."""

    def __init__(self, source, id):
        self._source = source
        self._id = id

    @property
    def singles(self):
        """Tuple of `HDFSingleEvent`, one per participating detector."""
        source = self._source
        columns = zip(source._ifos, source._trigger_ids,
                      source._triggers, source._psds)
        result = []
        for num, (ifo, trigger_ids, triggers, psds) in enumerate(columns):
            result.append(HDFSingleEvent(
                ifo, self._id, num, trigger_ids[self._id],
                source._timeslide_interval, triggers,
                source._timeslide_ids, psds))
        return tuple(result)

    @property
    def template_args(self):
        """Template bank parameters for this event, as a dict of columns
        indexed at this event's template ID."""
        bank = self._source._bank
        row = self._source._template_ids[self._id]
        return {key: column[row] for key, column in bank.items()}


class HDFSingleEvent(SingleEvent):
    """A single-detector trigger drawn from PyCBC HDF5 trigger files."""

    def __init__(
            self, detector, _coinc_id, _detector_num, _trigger_id,
            _timeslide_interval, _triggers, _timeslide_ids, _psds):
        self._detector = detector
        self._coinc_id = _coinc_id
        self._detector_num = _detector_num
        self._trigger_id = _trigger_id
        self._timeslide_interval = _timeslide_interval
        self._triggers = _triggers
        self._timeslide_ids = _timeslide_ids
        self._psds = _psds

    @property
    def detector(self):
        return self._detector

    @property
    def snr(self):
        # _triggers holds the [snr, coa_phase, end_time] columns.
        snr_column, _, _ = self._triggers
        return snr_column[self._trigger_id]

    @property
    def phase(self):
        _, phase_column, _ = self._triggers
        return phase_column[self._trigger_id]

    @property
    def zerolag_time(self):
        _, _, end_time_column = self._triggers
        return end_time_column[self._trigger_id]

    @property
    def time(self):
        # PyCBC does not specify which detector is time-shifted in time
        # slides. Since PyCBC's coincidence format currently supports only
        # two detectors, we arbitrarily apply half of the time slide to each
        # detector.
        shift = self._timeslide_ids[self._coinc_id] * self._timeslide_interval
        half_shift = 0.5 * shift
        if self._detector_num == 0:
            return self.zerolag_time - half_shift
        if self._detector_num == 1:
            return self.zerolag_time + half_shift
        raise AssertionError('This line should not be reached')

    @property
    def psd(self):
        when = self.zerolag_time
        try:
            psd = self._psds[self._psds.find(when)].psd
        except ValueError:
            raise ValueError(
                'No PSD found for detector {} at zero-lag GPS time {}'.format(
                    self.detector, when))

        # Divide out PyCBC's dynamic range factor (applied squared —
        # presumably because this is a power spectrum; confirm against the
        # PyCBC PSD file format) and trim below the low-frequency cutoff.
        dyn_range_fac = psd.file.attrs['dynamic_range_factor']
        flow = psd.file.attrs['low_frequency_cutoff']
        df = psd.attrs['delta_f']
        kmin = int(flow / df)

        fseries = lal.CreateREAL8FrequencySeries(
            'psd', 0, kmin * df, df,
            lal.DimensionlessUnit, len(psd) - kmin)
        fseries.data.data = psd[kmin:] / np.square(dyn_range_fac)
        return fseries


# Conventional module-level entry point: each events submodule exposes its
# event source class as ``open`` (shadowing the builtin within this module).
open = HDFEventSource

ligo/skymap/io/events/ligolw.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
# Copyright (C) 2017-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read events from pipedown/GstLal-style XML output."""
from collections import defaultdict
import errno
from functools import lru_cache
import itertools
import logging
import operator
import os

from ligo.lw import array, lsctables, param
from ligo.lw.ligolw import Element, LIGOLWContentHandler, LIGO_LW
from ligo.lw.lsctables import (
    CoincDefTable, CoincMapTable, CoincTable, ProcessTable, ProcessParamsTable,
    SnglInspiralTable, TimeSlideID, TimeSlideTable)
from ligo.lw.utils import load_filename, load_fileobj
import lal
import lal.series
from lalinspiral.thinca import InspiralCoincDef

from .base import Event, EventSource, SingleEvent
from ...util import ilwd

__all__ = ('LigoLWEventSource',)

log = logging.getLogger('BAYESTAR')


# The use_in class decorators extend the handler to parse LIGO-LW Array,
# Param, and Table elements; ilwd.use_in presumably adds tolerance for
# legacy ilwd:char row IDs — confirm in ligo.skymap.util.ilwd.
@ilwd.use_in
@array.use_in
@lsctables.use_in
@param.use_in
class ContentHandler(LIGOLWContentHandler):
    """LIGO-LW SAX content handler used for all XML parsing in this module."""
    pass


def _read_xml(f, fallbackpath=None):
    """Load a LIGO-LW document and return a ``(document, filename)`` pair.

    *f* may be None (returns ``(None, None)``), an already-parsed Element
    (returned as-is with filename ``''``), a filename/path, or a file-like
    object. For a relative path that does not exist, *fallbackpath* is tried
    as a parent directory before giving up.
    """
    if f is None:
        return None, None
    if isinstance(f, Element):
        return f, ''
    if isinstance(f, (str, os.PathLike)):
        try:
            doc = load_filename(f, contenthandler=ContentHandler)
        except IOError as e:
            retry = (e.errno == errno.ENOENT and fallbackpath and
                     not os.path.isabs(f))
            if not retry:
                raise
            f = os.path.join(fallbackpath, f)
            doc = load_filename(f, contenthandler=ContentHandler)
        return doc, f
    doc = load_fileobj(f, contenthandler=ContentHandler)
    return doc, getattr(f, 'name', '')


class LigoLWEventSource(dict, EventSource):
    """Read events from LIGO-LW XML files.

    Parameters
    ----------
    f : str, file-like object, or `ligo.lw.ligolw.Document`
        The name of the file, or the file object, or the XML document object,
        containing the trigger tables.
    psd_file : str, file-like object, or `ligo.lw.ligolw.Document`
        The name of the file, or the file object, or the XML document object,
        containing the PSDs. If not supplied, then PSDs will be read
        automatically from any files mentioned in ``--reference-psd`` command
        line arguments stored in the ProcessParams table.
    coinc_def : `ligo.lw.lsctables.CoincDef`, optional
        An optional coinc definer to limit which events are read.
    fallbackpath : str, optional
        A directory to search for PSD files whose ``--reference-psd`` command
        line arguments have relative paths. By default, the current working
        directory and the directory containing file ``f`` are searched.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    """

    def __init__(self, f, psd_file=None, coinc_def=InspiralCoincDef,
                 fallbackpath=None, **kwargs):
        doc, filename = _read_xml(f)
        # Resolve relative PSD paths against the directory containing the
        # input file when known; otherwise fall back to fallbackpath.
        self._fallbackpath = (
            os.path.dirname(filename) if filename else fallbackpath)
        # Rebind the PSD reader through a per-instance cache so that
        # repeated lookups of the same PSD file parse the XML only once.
        self._psds_for_file = lru_cache(maxsize=None)(self._psds_for_file)
        super().__init__(self._make_events(doc, psd_file, coinc_def))

    def __str__(self):
        contents = repr(self)
        return '<{}>'.format(contents)

    def __repr__(self):
        contents = super().__repr__()
        return '{}({})'.format(self.__class__.__name__, contents)

    # sngl_inspiral columns that define the template waveform; these must be
    # identical across detectors for a given coincidence (checked in
    # _make_events).
    _template_keys = '''mass1 mass2
                        spin1x spin1y spin1z spin2x spin2y spin2z
                        f_final'''.split()

    # Map from pipeline (ProcessTable program name) to whether recorded
    # coalescence phases must be negated (anti-FINDCHIRP convention).
    _invert_phases = {
        'pycbc': False,
        'gstlal_inspiral': True,
        'gstlal_inspiral_postcohspiir_online': True,  # FIXME: wild guess
        'bayestar_realize_coincs': True,
        'bayestar-realize-coincs': True,
        'MBTAOnline': True
    }

    @classmethod
    def _phase_convention(cls, program):
        """Look up the phase convention for *program*; raise KeyError if it
        is not a known pipeline."""
        try:
            return cls._invert_phases[program]
        except KeyError:
            raise KeyError(
                ('The pipeline "{}" is unknown, so the phase '
                 'convention could not be deduced.').format(program))

    def _psds_for_file(self, f):
        """Read a dictionary of PSDs from LIGO-LW file *f*.

        Wrapped with a per-instance lru_cache in __init__.
        """
        doc, _ = _read_xml(f, self._fallbackpath)
        return lal.series.read_psd_xmldoc(doc, root_name=None)

    def _make_events(self, doc, psd_file, coinc_def):
        """Yield ``(coinc_event_num, LigoLWEvent)`` pairs from *doc*."""
        # Look up necessary tables.
        coinc_table = CoincTable.get_table(doc)
        coinc_map_table = CoincMapTable.get_table(doc)
        sngl_inspiral_table = SnglInspiralTable.get_table(doc)
        try:
            time_slide_table = TimeSlideTable.get_table(doc)
        except ValueError:
            # Absence of the table is tolerated; zero-lag is assumed below.
            offsets_by_time_slide_id = None
        else:
            offsets_by_time_slide_id = time_slide_table.as_dict()

        # Indices to speed up lookups by ID.
        key = operator.attrgetter('coinc_event_id')
        event_ids_by_coinc_event_id = {
            coinc_event_id:
                tuple(coinc_map.event_id for coinc_map in coinc_maps
                      if coinc_map.table_name == SnglInspiralTable.tableName)
            for coinc_event_id, coinc_maps
            in itertools.groupby(sorted(coinc_map_table, key=key), key=key)}
        sngl_inspirals_by_event_id = {
            row.event_id: row for row in sngl_inspiral_table}

        # Filter rows by coinc_def if requested.
        if coinc_def is not None:
            coinc_def_table = CoincDefTable.get_table(doc)
            coinc_def_ids = {
                row.coinc_def_id for row in coinc_def_table
                if (row.search, row.search_coinc_type) ==
                (coinc_def.search, coinc_def.search_coinc_type)}
            coinc_table = [
                row for row in coinc_table
                if row.coinc_def_id in coinc_def_ids]

        # Map event IDs to any SNR time series embedded in the document.
        snr_dict = dict(self._snr_series_by_sngl_inspiral(doc))

        process_table = ProcessTable.get_table(doc)
        program_for_process_id = {
            row.process_id: row.program for row in process_table}

        try:
            process_params_table = ProcessParamsTable.get_table(doc)
        except ValueError:
            psd_filenames_by_process_id = {}
        else:
            # PSD files recorded as --reference-psd command line arguments.
            psd_filenames_by_process_id = {
                process_param.process_id: process_param.value
                for process_param in process_params_table
                if process_param.param == '--reference-psd'}

        # Warn once per time slide ID if the time slide table is missing.
        ts0 = TimeSlideID(0)
        for time_slide_id in {coinc.time_slide_id for coinc in coinc_table}:
            if offsets_by_time_slide_id is None and time_slide_id == ts0:
                log.warning(
                    'Time slide record is missing for %s, '
                    'guessing that this is zero-lag', time_slide_id)

        # Warn once per program if phase inversion will be applied.
        for program in {program_for_process_id[coinc.process_id]
                        for coinc in coinc_table}:
            invert_phases = self._phase_convention(program)
            if invert_phases:
                log.warning(
                    'Using anti-FINDCHIRP phase convention; inverting phases. '
                    'This is currently the default and it is appropriate for '
                    'gstlal and MBTA but not pycbc as of observing run 1 '
                    '("O1"). The default setting is likely to change in the '
                    'future.')

        for coinc in coinc_table:
            coinc_event_id = coinc.coinc_event_id
            coinc_event_num = int(coinc_event_id)
            sngls = [sngl_inspirals_by_event_id[event_id] for event_id
                     in event_ids_by_coinc_event_id[coinc_event_id]]
            if offsets_by_time_slide_id is None and coinc.time_slide_id == ts0:
                # No time slide table: treat zero-lag offsets as all zero.
                offsets = defaultdict(float)
            else:
                offsets = offsets_by_time_slide_id[coinc.time_slide_id]

            # Template parameters must agree across all detectors.
            template_args = [
                {key: getattr(sngl, key) for key in self._template_keys}
                for sngl in sngls]
            if any(d != template_args[0] for d in template_args[1:]):
                raise ValueError(
                    'Template arguments are not identical for all detectors!')
            template_args = template_args[0]

            invert_phases = self._phase_convention(
                program_for_process_id[coinc.process_id])

            singles = tuple(LigoLWSingleEvent(
                self, sngl.ifo, sngl.snr, sngl.coa_phase,
                float(sngl.end + offsets[sngl.ifo]), float(sngl.end),
                psd_file or psd_filenames_by_process_id.get(sngl.process_id),
                snr_dict.get(sngl.event_id), invert_phases)
                for sngl in sngls)

            event = LigoLWEvent(coinc_event_num, singles, template_args)

            yield coinc_event_num, event

    @classmethod
    def _snr_series_by_sngl_inspiral(cls, doc):
        """Yield ``(event_id, COMPLEX8TimeSeries)`` pairs for any SNR time
        series embedded in *doc*."""
        for elem in doc.getElementsByTagName(LIGO_LW.tagName):
            try:
                if elem.Name != lal.COMPLEX8TimeSeries.__name__:
                    continue
                array.get_array(elem, 'snr')
                event_id = param.get_pyvalue(elem, 'event_id')
            except (AttributeError, ValueError):
                # Not an SNR time series element; skip it.
                continue
            else:
                yield event_id, lal.series.parse_COMPLEX8TimeSeries(elem)


class LigoLWEvent(Event):
    """A coincident event read from a LIGO-LW document."""

    def __init__(self, id, singles, template_args):
        self._id = id
        self._template_args = template_args
        self._singles = singles

    @property
    def singles(self):
        """The `LigoLWSingleEvent` instances for this coincidence."""
        return self._singles

    @property
    def template_args(self):
        """Dictionary of template parameters shared by all detectors."""
        return self._template_args


class LigoLWSingleEvent(SingleEvent):
    """Single-detector event read from a LIGO-LW document.

    If ``invert_phases`` is true, the reported phase is negated and the SNR
    time series is conjugated on access.
    """

    def __init__(self, source, detector, snr, phase, time, zerolag_time,
                 psd_file, snr_series, invert_phases):
        (self._source, self._detector, self._snr, self._phase, self._time,
         self._zerolag_time, self._psd_file, self._snr_series,
         self._invert_phases) = (
            source, detector, snr, phase, time, zerolag_time,
            psd_file, snr_series, invert_phases)

    @property
    def detector(self):
        return self._detector

    @property
    def snr(self):
        return self._snr

    @property
    def phase(self):
        if self._phase is None:
            return None
        return -self._phase if self._invert_phases else self._phase

    @property
    def time(self):
        return self._time

    @property
    def zerolag_time(self):
        return self._zerolag_time

    @property
    def psd(self):
        return self._source._psds_for_file(self._psd_file)[self._detector]

    @property
    def snr_series(self):
        series = self._snr_series
        if series is None or not self._invert_phases:
            return series
        # Cut out the full span — presumably to obtain a fresh copy so the
        # stored series is left unmodified — then conjugate the copy.
        flipped = lal.CutCOMPLEX8TimeSeries(series, 0, len(series.data.data))
        flipped.data.data = flipped.data.data.conj()
        return flipped


# Module-level alias: deliberately shadows the builtin ``open`` so callers
# can use ``.ligolw.open(...)`` as the entry point for this reader.
open = LigoLWEventSource

ligo/skymap/io/events/magic.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read events from either HDF or LIGO-LW files."""
import builtins
import sqlite3

from ligo.lw.ligolw import Element
import h5py

from . import hdf, ligolw, sqlite

__all__ = ('MagicEventSource', 'open')


def _read_file_header(f, nbytes=16):
    """Read the first 16 bytes of a file

    This is presumed to include the characters that declare the
    file type.

    Parameters
    ----------
    f : file, str
        A file object or the path to a file

    Returns
    -------
    header : bytes
        A string (hopefully) describing the file type

    """
    try:
        pos = f.tell()
    except AttributeError:
        with builtins.open(f, "rb") as fobj:
            return fobj.read(nbytes)
    try:
        return f.read(nbytes)
    finally:
        f.seek(pos)


def MagicEventSource(f, *args, **kwargs):  # noqa: N802
    """Read events from LIGO-LW XML, LIGO-LW SQlite, or HDF5 files. The format
    is determined automatically from the type of ``f`` (if it is already an
    open HDF5 file, SQLite connection, or LIGO-LW document) or by sniffing
    the file's magic number, and then the file is opened using
    :obj:`.ligolw.open`, :obj:`.sqlite.open`, or :obj:`.hdf.open`, as
    appropriate.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    """
    # Already-open handles: dispatch on the object type directly.
    if isinstance(f, h5py.File):
        opener = hdf.open
    elif isinstance(f, sqlite3.Connection):
        opener = sqlite.open
    elif isinstance(f, Element):
        opener = ligolw.open
    else:
        # Otherwise sniff the leading bytes for a known magic number.
        fileheader = _read_file_header(f)
        if fileheader.startswith(b'\x89HDF\r\n\x1a\n'):  # HDF5
            opener = hdf.open
        elif fileheader.startswith(b'SQLite format 3'):  # SQLite 3
            opener = sqlite.open
        elif fileheader.startswith((
                b'<?xml',  # XML
                b'\x1f\x8b\x08',  # GZIP
        )):
            opener = ligolw.open
        else:
            raise IOError('Unknown file format')
    return opener(f, *args, **kwargs)


# Module-level alias: deliberately shadows the builtin ``open`` so callers
# can use ``.magic.open(...)`` as the entry point for this reader.
open = MagicEventSource

ligo/skymap/io/events/sqlite.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read events from a GstLal-style SQLite output."""
import os
import sqlite3

from ligo.lw import dbtables

from ...util import sqlite
from .ligolw import LigoLWEventSource

__all__ = ('SQLiteEventSource',)


class SQLiteEventSource(LigoLWEventSource):
    """Read events from LIGO-LW SQLite files.

    Parameters
    ----------
    f : str, file-like object, or `sqlite3.Connection` instance
        The SQLite database.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    """

    def __init__(self, f, *args, **kwargs):
        if isinstance(f, sqlite3.Connection):
            db = f
            filename = sqlite.get_filename(f)
        elif hasattr(f, 'read'):
            # An already-open file object: remember its path, close it, and
            # reopen read-only through the sqlite helper.
            filename = f.name
            f.close()
            db = sqlite.open(filename, 'r')
        else:
            filename = f
            db = sqlite.open(filename, 'r')
        super().__init__(dbtables.get_xml(db), *args, **kwargs)
        # Directory of the database, used as a fallback search path.
        self._fallbackpath = os.path.dirname(filename) if filename else None


# Module-level alias: deliberately shadows the builtin ``open`` so callers
# can use ``.sqlite.open(...)`` as the entry point for this reader.
open = SQLiteEventSource

ligo/skymap/plot/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

__all__ = ()

# Import all symbols from all submodules of this module.
# ``exec`` is used because the submodule name is only known at run time:
# each submodule is imported, its ``__all__`` (if any) is folded into this
# package's ``__all__``, and then its public names are star-imported here.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    del module

# Clean up: remove loop temporaries so they do not leak into the package
# namespace.
del os, pkgutil

ligo/skymap/plot/allsky.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
548  
549  
550  
551  
552  
553  
554  
555  
556  
557  
558  
559  
560  
561  
562  
563  
564  
565  
566  
567  
568  
569  
570  
571  
572  
573  
574  
575  
576  
577  
578  
579  
580  
581  
582  
583  
584  
585  
586  
587  
588  
589  
590  
591  
592  
593  
594  
595  
596  
597  
598  
599  
600  
601  
602  
603  
604  
605  
606  
607  
608  
609  
610  
611  
612  
613  
614  
615  
616  
617  
618  
619  
620  
621  
622  
623  
624  
625  
626  
627  
628  
629  
630  
631  
632  
633  
634  
635  
636  
637  
638  
639  
640  
641  
642  
643  
644  
645  
646  
647  
648  
649  
650  
651  
652  
653  
654  
655  
656  
657  
658  
659  
660  
661  
662  
663  
664  
665  
666  
667  
668  
669  
670  
671  
672  
673  
674  
675  
676  
677  
678  
679  
680  
681  
682  
683  
684  
685  
686  
687  
688  
689  
690  
691  
692  
693  
694  
695  
696  
697  
698  
699  
700  
701  
702  
703  
704  
705  
706  
707  
708  
709  
710  
711  
712  
713  
714  
715  
716  
717  
718  
719  
720  
721  
722  
723  
724  
725  
726  
727  
728  
729  
730  
731  
732  
733  
734  
735  
736  
737  
738  
739  
740  
741  
742  
743  
744  
745  
746  
747  
748  
749  
750  
751  
752  
753  
754  
755  
756  
757  
758  
759  
760  
761  
762  
763  
764  
765  
766  
767  
768  
769  
770  
771  
772  
773  
774  
775  
776  
777  
778  
779  
780  
781  
782  
783  
784  
785  
786  
787  
788  
789  
790  
791  
792  
793  
794  
795  
796  
797  
798  
799  
800  
801  
802  
803  
804  
805  
806  
807  
808  
809  
810  
811  
812  
813  
814  
815  
816  
817  
818  
819  
820  
821  
822  
823  
824  
825  
826  
827  
828  
829  
830  
831  
832  
833  
834  
835  
836  
837  
838  
839  
840  
841  
842  
843  
844  
845  
846  
847  
848  
849  
850  
851  
852  
#
# Copyright (C) 2012-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Axes subclasses for astronomical mapmaking.

This module adds several :class:`astropy.visualization.wcsaxes.WCSAxes`
subclasses to the Matplotlib projection registry. The projections have names of
the form :samp:`{astro_or_geo_or_galactic} [{lon_units}] {projection}`.

:samp:`{astro_or_geo_or_galactic}` may be ``astro``, ``geo``, or ``galactic``.
It controls the reference frame, either celestial (ICRS), terrestrial (ITRS),
or galactic.

:samp:`{lon_units}` may be ``hours`` or ``degrees``. It controls the units of
the longitude axis. If omitted, ``astro`` implies ``hours`` and ``geo`` implies
``degrees``.

:samp:`{projection}` may be any of the following:

* ``aitoff`` for the Aitoff all-sky projection

* ``mollweide`` for the Mollweide all-sky projection

* ``globe`` for an orthographic projection, like the three-dimensional view of
  the Earth from a distant satellite

* ``zoom`` for a gnomonic projection suitable for visualizing small zoomed-in
  patches

All projections support the ``center`` argument, while some support additional
arguments. The ``globe`` projections also support the ``rotate`` argument, and
the ``zoom`` projections also support the ``radius`` and ``rotate`` arguments.

Examples
--------
.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='astro hours mollweide')
    ax.grid()

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='geo aitoff')
    ax.grid()

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='astro zoom',
                  center='5h -32d', radius='5 deg', rotate='20 deg')
    ax.grid()

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='geo globe', center='-50d +23d')
    ax.grid()

Insets
------
You can use insets to link zoom-in views between axes. There are two supported
styles of insets: rectangular and circular (loupe). The example below shows
both kinds of insets.

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    fig = plt.figure(figsize=(9, 4), dpi=100)

    ax_globe = plt.axes(
        [0.1, 0.1, 0.8, 0.8],
        projection='astro degrees globe',
        center='120d +23d')

    ax_zoom_rect = plt.axes(
        [0.0, 0.2, 0.4, 0.4],
        projection='astro degrees zoom',
        center='150d +30d',
        radius='9 deg')

    ax_zoom_circle = plt.axes(
        [0.55, 0.1, 0.6, 0.6],
        projection='astro degrees zoom',
        center='120d +10d',
        radius='5 deg')

    ax_globe.mark_inset_axes(ax_zoom_rect)
    ax_globe.connect_inset_axes(ax_zoom_rect, 'upper left')
    ax_globe.connect_inset_axes(ax_zoom_rect, 'lower right')

    ax_globe.mark_inset_circle(ax_zoom_circle, '120d +10d', '4 deg')
    ax_globe.connect_inset_circle(ax_zoom_circle, '120d +10d', '4 deg')

    ax_globe.grid()
    ax_zoom_rect.grid()
    ax_zoom_circle.grid()

    for ax in [ax_globe, ax_zoom_rect, ax_zoom_circle]:
        ax.set_facecolor('none')
        for key in ['ra', 'dec']:
            ax.coords[key].set_auto_axislabel(False)

Complete Example
----------------
The following example demonstrates most of the features of this module.

.. plot::
   :context: reset
   :include-source:
   :align: center

    from astropy.coordinates import SkyCoord
    from astropy.io import fits
    from astropy import units as u
    import ligo.skymap.plot
    from matplotlib import pyplot as plt

    url = 'https://dcc.ligo.org/public/0146/G1701985/001/bayestar_no_virgo.fits.gz'
    center = SkyCoord.from_name('NGC 4993')

    fig = plt.figure(figsize=(4, 4), dpi=100)

    ax = plt.axes(
        [0.05, 0.05, 0.9, 0.9],
        projection='astro globe',
        center=center)

    ax_inset = plt.axes(
        [0.59, 0.3, 0.4, 0.4],
        projection='astro zoom',
        center=center,
        radius=10*u.deg)

    for key in ['ra', 'dec']:
        ax_inset.coords[key].set_ticklabel_visible(False)
        ax_inset.coords[key].set_ticks_visible(False)
    ax.grid()
    ax.mark_inset_axes(ax_inset)
    ax.connect_inset_axes(ax_inset, 'upper left')
    ax.connect_inset_axes(ax_inset, 'lower left')
    ax_inset.scalebar((0.1, 0.1), 5 * u.deg).label()
    ax_inset.compass(0.9, 0.1, 0.2)

    ax.imshow_hpx(url, cmap='cylon')
    ax_inset.imshow_hpx(url, cmap='cylon')
    ax_inset.plot(
        center.ra.deg, center.dec.deg,
        transform=ax_inset.get_transform('world'),
        marker=ligo.skymap.plot.reticle(),
        markersize=30,
        markeredgewidth=3)

"""  # noqa: E501
from itertools import product

from astropy.coordinates import SkyCoord, UnitSphericalRepresentation
from astropy.io.fits import Header
from astropy.time import Time
from astropy.visualization.wcsaxes import SphericalCircle, WCSAxes
from astropy.visualization.wcsaxes.formatter_locator import (
    AngleFormatterLocator)
from astropy.visualization.wcsaxes.frame import EllipticalFrame
from astropy.wcs import WCS
from astropy import units as u
from matplotlib import rcParams
from matplotlib.offsetbox import AnchoredOffsetbox
from matplotlib.patches import ConnectionPatch, FancyArrowPatch, PathPatch
from matplotlib.path import Path
from matplotlib.projections import projection_registry
import numpy as np
from reproject import reproject_from_healpix
from scipy.optimize import minimize_scalar
from .angle import reference_angle_deg, wrapped_angle_deg

__all__ = ['AutoScaledWCSAxes', 'ScaleBar']


class WCSInsetPatch(PathPatch):
    """`matplotlib.patches.PathPatch` that traces the frame of one
    `astropy.visualization.wcsaxes.WCSAxes` so that its footprint can be
    outlined inside another axes.
    """

    def __init__(self, ax, *args, **kwargs):
        self._ax = ax
        frame = ax.coords.frame
        super().__init__(
            None, *args, fill=False,
            edgecolor=frame.get_color(),
            linewidth=frame.get_linewidth(),
            **kwargs)

    def get_path(self):
        frame = self._ax.coords.frame
        raw = frame.patch.get_path()
        # Densify the outline before transforming so that curved frame
        # edges remain smooth in the target projection.
        return raw.interpolated(50).transformed(frame.transform)


class WCSInsetConnectionPatch(ConnectionPatch):
    """Patch to connect an inset WCS axes inside another WCS axes.

    Notes
    -----
    FIXME: This class assumes that the projection of the circle in figure-inch
    coordinates *is* a circle. It will have noticeable artifacts if the
    projection is very distorted."""

    # Maps AnchoredOffsetbox location codes to indices into the array
    # returned by viewLim.corners().
    _corners_map = {1: 3, 2: 1, 3: 0, 4: 2}

    def __init__(self, ax, ax_inset, loc, **kwargs):
        # Accept either a string location ('upper left', ...) or an int code.
        try:
            loc = AnchoredOffsetbox.codes[loc]
        except KeyError:
            loc = int(loc)
        # Find the selected corner of the inset in its own frame coordinates,
        # then map it into the outer axes' frame.
        xy_inset = ax_inset.viewLim.corners()[self._corners_map[loc]]
        frame_to_frame = (ax_inset.coords.frame.transform +
                          ax.coords.frame.transform.inverted())
        xy = frame_to_frame.transform_point(xy_inset)
        inset_frame = ax_inset.coords.frame
        super().__init__(
            xy, xy_inset, 'data', 'data', axesA=ax, axesB=ax_inset,
            color=inset_frame.get_color(),
            linewidth=inset_frame.get_linewidth(),
            **kwargs)


class WCSCircleInsetConnectionPatch(PathPatch):
    """Patch to connect a circular inset WCS axes inside another WCS axes.

    Draws one of the two outer tangent lines between the circle as it
    appears in each of the two axes.

    Parameters
    ----------
    ax1, ax2 : `astropy.visualization.wcsaxes.WCSAxes`
        The two axes to connect.
    coord : `astropy.coordinates.SkyCoord`
        The center of the circle.
    radius : `astropy.units.Quantity`
        The radius of the circle in angle-compatible units.
    sign : int
        Which outer tangent to draw: +1 or -1.
    """

    def __init__(self, ax1, ax2, coord, radius, sign, *args, **kwargs):
        self._axs = (ax1, ax2)
        self._coord = coord.icrs
        self._radius = radius
        self._sign = sign
        super().__init__(None, *args, **kwargs, clip_on=False, transform=None)

    def get_path(self):
        # Calculate the position and radius of the inset in figure-inch
        # coordinates: project the center and a point offset by one radius
        # through each axes' world transform.
        offset = self._coord.directional_offset_by(0 * u.deg, self._radius)
        transforms = [ax.get_transform('world') for ax in self._axs]
        centers = np.asarray([
            tx.transform_point((self._coord.ra.deg, self._coord.dec.deg))
            for tx in transforms])
        offsets = np.asarray([
            tx.transform_point((offset.ra.deg, offset.dec.deg))
            for tx in transforms])

        # Plot outer tangents of the two projected circles.
        r0, r1 = np.sqrt(np.sum(np.square(centers - offsets), axis=-1))
        dx, dy = np.diff(centers, axis=0).ravel()
        gamma = -np.arctan(dy / dx)
        beta = np.arcsin((r1 - r0) / np.sqrt(np.square(dx) + np.square(dy)))
        alpha = gamma - self._sign * beta
        p0 = centers[0] + self._sign * np.asarray([
            r0 * np.sin(alpha), r0 * np.cos(alpha)])
        p1 = centers[1] + self._sign * np.asarray([
            r1 * np.sin(alpha), r1 * np.cos(alpha)])
        # Use np.vstack: np.row_stack is a deprecated alias (deprecated
        # since NumPy 2.0) with identical behavior.
        return Path(np.vstack((p0, p1)), np.asarray([
            Path.MOVETO, Path.LINETO]))


class AutoScaledWCSAxes(WCSAxes):
    """Axes base class. The pixel scale is adjusted to the DPI of the image,
    and there are a variety of convenience methods.
    """

    name = 'astro wcs'

    def __init__(self, *args, header, obstime=None, **kwargs):
        """Set up the axes, rescaling the WCS in *header* so that the image
        resolution matches the axes' bounding box (and therefore the figure
        DPI).

        Parameters
        ----------
        header : `astropy.io.fits.Header` or dict-like
            FITS WCS header; must contain the NAXIS1/2, CRPIX1/2, and
            CDELT1/2 keys used below.
        obstime : optional
            If given, recorded in the header as MJD-OBS/DATE-OBS in UTC.
        """
        super().__init__(*args, aspect=1, **kwargs)
        h = Header(header, copy=True)
        naxis1 = h['NAXIS1']
        naxis2 = h['NAXIS2']
        # Scale factor that fits the image inside the axes bounding box
        # while preserving its aspect ratio.
        scale = min(self.bbox.width / naxis1, self.bbox.height / naxis2)
        h['NAXIS1'] = int(np.ceil(naxis1 * scale))
        h['NAXIS2'] = int(np.ceil(naxis2 * scale))
        # Actual (post-rounding) per-axis scale factors.
        scale1 = h['NAXIS1'] / naxis1
        scale2 = h['NAXIS2'] / naxis2
        # Move the (1-based) reference pixel proportionally onto the new
        # pixel grid, and stretch the pixel size to compensate, so the
        # projected sky footprint is unchanged.
        h['CRPIX1'] = (h['CRPIX1'] - 1) * (h['NAXIS1'] - 1) / (naxis1 - 1) + 1
        h['CRPIX2'] = (h['CRPIX2'] - 1) * (h['NAXIS2'] - 1) / (naxis2 - 1) + 1
        h['CDELT1'] /= scale1
        h['CDELT2'] /= scale2
        if obstime is not None:
            h['MJD-OBS'] = Time(obstime).utc.mjd
            h['DATE-OBS'] = Time(obstime).utc.isot
        self.reset_wcs(WCS(h))
        self.set_xlim(-0.5, h['NAXIS1'] - 0.5)
        self.set_ylim(-0.5, h['NAXIS2'] - 0.5)
        self._header = h

    @property
    def header(self):
        """The DPI-adjusted FITS header used to build this axes' WCS."""
        return self._header

    def mark_inset_axes(self, ax, *args, **kwargs):
        """Outline the footprint of another WCSAxes inside this one.

        Parameters
        ----------
        ax : `astropy.visualization.wcsaxes.WCSAxes`
            The other axes.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.PathPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.PathPatch`

        Returns
        -------
        patch : `matplotlib.patches.PathPatch`

        """
        return self.add_patch(WCSInsetPatch(
            ax, *args, transform=self.get_transform('world'), **kwargs))

    def mark_inset_circle(self, ax, center, radius, *args, **kwargs):
        """Outline a circle in this and another Axes to create a loupe.

        Parameters
        ----------
        ax : `astropy.visualization.wcsaxes.WCSAxes`
            The other axes.
        coord : `astropy.coordinates.SkyCoord`
            The center of the circle.
        radius : `astropy.units.Quantity`
            The radius of the circle in units that are compatible with degrees.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.PathPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.PathPatch`

        Returns
        -------
        patch1 : `matplotlib.patches.PathPatch`
            The outline of the circle in these Axes.
        patch2 : `matplotlib.patches.PathPatch`
            The outline of the circle in the other Axes.
        """
        center = SkyCoord(
            center, representation_type=UnitSphericalRepresentation).icrs
        radius = u.Quantity(radius)
        args = ((center.ra, center.dec), radius, *args)
        kwargs = {'facecolor': 'none',
                  'edgecolor': rcParams['axes.edgecolor'],
                  'linewidth': rcParams['axes.linewidth'],
                  **kwargs}
        # NOTE: the loop variable deliberately reuses (shadows) the ``ax``
        # parameter; a circle patch is added to both axes.
        for ax in (self, ax):
            ax.add_patch(SphericalCircle(*args, **kwargs,
                                         transform=ax.get_transform('world')))

    def connect_inset_axes(self, ax, loc, *args, **kwargs):
        """Connect a corner of another WCSAxes to the matching point inside
        this one.

        Parameters
        ----------
        ax : `astropy.visualization.wcsaxes.WCSAxes`
            The other axes.
        loc : int, str
            Which corner to connect. For valid values, see
            `matplotlib.offsetbox.AnchoredOffsetbox`.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.ConnectionPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.ConnectionPatch`

        Returns
        -------
        patch : `matplotlib.patches.ConnectionPatch`

        """
        return self.add_patch(WCSInsetConnectionPatch(
            self, ax, loc, *args, **kwargs))

    def connect_inset_circle(self, ax, center, radius, *args, **kwargs):
        """Connect a circle in this and another Axes to create a loupe.

        Parameters
        ----------
        ax : `astropy.visualization.wcsaxes.WCSAxes`
            The other axes.
        coord : `astropy.coordinates.SkyCoord`
            The center of the circle.
        radius : `astropy.units.Quantity`
            The radius of the circle in units that are compatible with degrees.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.PathPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.PathPatch`

        Returns
        -------
        patch1, patch2 : `matplotlib.patches.ConnectionPatch`
            The two connecting patches.
        """
        center = SkyCoord(
            center, representation_type=UnitSphericalRepresentation).icrs
        radius = u.Quantity(radius)
        kwargs = {'color': rcParams['axes.edgecolor'],
                  'linewidth': rcParams['axes.linewidth'],
                  **kwargs}
        # One patch for each of the two outer tangent lines.
        for sign in (-1, 1):
            self.add_patch(WCSCircleInsetConnectionPatch(
                self, ax, center, radius, sign, *args, **kwargs))

    def compass(self, x, y, size):
        """Add a compass to indicate the north and east directions.

        Parameters
        ----------
        x, y : float
            Position of compass vertex in axes coordinates.
        size : float
            Size of compass in axes coordinates.

        """
        xy = x, y
        scale = self.wcs.pixel_scale_matrix
        # Normalize to unit determinant so the arrows indicate direction
        # only, not pixel scale.
        scale /= np.sqrt(np.abs(np.linalg.det(scale)))
        # NOTE(review): the ``ha``/``va`` values from the zip are unused;
        # annotate() hardcodes ha='center', va='center'. Confirm intent.
        return [self.annotate(label, xy, xy + size * n,
                              self.transAxes, self.transAxes,
                              ha='center', va='center',
                              arrowprops=dict(arrowstyle='<-',
                                              shrinkA=0.0, shrinkB=0.0))
                for n, label, ha, va in zip(scale, 'EN',
                                            ['right', 'center'],
                                            ['center', 'bottom'])]

    def scalebar(self, *args, **kwargs):
        """Add scale bar.

        Parameters
        ----------
        xy : tuple
            The axes coordinates of the scale bar.
        length : `astropy.units.Quantity`
            The length of the scale bar in angle-compatible units.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.FancyArrowPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.FancyArrowPatch`

        Returns
        -------
        patch : `matplotlib.patches.FancyArrowPatch`

        """
        return self.add_patch(ScaleBar(self, *args, **kwargs))

    def _reproject_hpx(self, data, hdu_in=None, order='bilinear',
                       nested=False, field=0, smooth=None):
        """Reproject a HEALPix data set onto this axes' WCS grid, returning
        a masked array (masked where the projection has no valid data)."""
        if isinstance(data, np.ndarray):
            # A bare array is assumed to be in the same frame as the axes.
            data = (data, self.header['RADESYS'])

        # It's normal for reproject_from_healpix to produce some Numpy invalid
        # value warnings for points that land outside the projection.
        with np.errstate(invalid='ignore'):
            img, mask = reproject_from_healpix(
                data, self.header, hdu_in=hdu_in, order=order, nested=nested,
                field=field)
        # reproject's mask is True where valid; np.ma masks where True,
        # hence the inversion.
        img = np.ma.array(img, mask=~mask.astype(bool))

        if smooth is not None:
            # Infrequently used imports
            from astropy.convolution import convolve_fft, Gaussian2DKernel

            # Convert the smoothing length from angle to pixels.
            pixsize = np.mean(np.abs(self.wcs.wcs.cdelt)) * u.deg
            smooth = (smooth / pixsize).to(u.dimensionless_unscaled).value
            kernel = Gaussian2DKernel(smooth)
            # Ignore divide by zero warnings for pixels that have no valid
            # neighbors.
            with np.errstate(invalid='ignore'):
                img = convolve_fft(img, kernel, fill_value=np.nan)

        return img

    def contour_hpx(self, data, hdu_in=None, order='bilinear', nested=False,
                    field=0, smooth=None, **kwargs):
        """Add contour levels for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `~astropy.io.fits.TableHDU` or `~astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field, smooth :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.contour`

        Returns
        -------
        contours : `matplotlib.contour.QuadContourSet`

        """  # noqa: E501
        img = self._reproject_hpx(data, hdu_in=hdu_in, order=order,
                                  nested=nested, field=field, smooth=smooth)
        return self.contour(img, **kwargs)

    def contourf_hpx(self, data, hdu_in=None, order='bilinear', nested=False,
                     field=0, smooth=None, **kwargs):
        """Add filled contour levels for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `~astropy.io.fits.TableHDU` or `~astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field, smooth :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.contour`

        Returns
        -------
        contours : `matplotlib.contour.QuadContourSet`

        """  # noqa: E501
        img = self._reproject_hpx(data, hdu_in=hdu_in, order=order,
                                  nested=nested, field=field, smooth=smooth)
        return self.contourf(img, **kwargs)

    def imshow_hpx(self, data, hdu_in=None, order='bilinear', nested=False,
                   field=0, smooth=None, **kwargs):
        """Add an image for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `~astropy.io.fits.TableHDU` or `~astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field, smooth :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.contour`

        Returns
        -------
        image : `matplotlib.image.AxesImage`

        """  # noqa: E501
        img = self._reproject_hpx(data, hdu_in=hdu_in, order=order,
                                  nested=nested, field=field, smooth=smooth)
        return self.imshow(img, **kwargs)


class ScaleBar(FancyArrowPatch):
    """A horizontal bar, drawn in axes coordinates, spanning a fixed angle.

    The bar starts at the axes-fraction position ``xy`` and its width is
    solved numerically so that it subtends the requested angular ``length``
    on the sky under the axes' world coordinate transform.
    """

    def _func(self, dx, x, y):
        # Objective for the bounded 1-D minimization in __init__: the squared
        # difference between the angular separation spanned by a bar of axes
        # width dx and the requested length.
        p1, p2 = self._transAxesToWorld.transform([[x, y], [x + dx, y]])
        p1 = SkyCoord(*p1, unit=u.deg)
        p2 = SkyCoord(*p2, unit=u.deg)
        return np.square((p1.separation(p2) - self._length).value)

    def __init__(self, ax, xy, length, *args, **kwargs):
        # xy: (x, y) position of the left end in axes-fraction coordinates.
        # length: angular extent (anything accepted by u.Quantity, e.g.
        # '1 deg').
        x, y = xy
        self._ax = ax
        self._length = u.Quantity(length)
        # Composite transform: axes fraction -> data coordinates -> world
        # (longitude/latitude in degrees).
        self._transAxesToWorld = (
            (ax.transAxes - ax.transData) + ax.coords.frame.transform)
        # Solve for the axes-fraction width dx whose on-sky separation equals
        # the requested length; bounded so the bar stays inside the axes.
        dx = minimize_scalar(
            self._func, args=xy, bounds=[0, 1 - x], method='bounded').x
        # Caller-supplied kwargs override the defaults below.
        custom_kwargs = kwargs
        kwargs = dict(
            capstyle='round',
            color='black',
            linewidth=rcParams['lines.linewidth'],
        )
        kwargs.update(custom_kwargs)
        # Draw as a plain line (arrowstyle '-') between the two endpoints.
        super().__init__(
            xy, (x + dx, y),
            *args,
            arrowstyle='-',
            shrinkA=0.0,
            shrinkB=0.0,
            transform=ax.transAxes,
            **kwargs)

    def label(self, **kwargs):
        # Place a text label (e.g. ' 1°') centered above the bar.
        # NOTE(review): reads FancyArrowPatch's private _posA_posB attribute
        # to recover the endpoints — verify against the installed matplotlib
        # version.
        (x0, y), (x1, _) = self._posA_posB
        s = ' {0.value:g}{0.unit:unicode}'.format(self._length)
        return self._ax.text(
            0.5 * (x0 + x1), y, s,
            ha='center', va='bottom', transform=self._ax.transAxes, **kwargs)


class Astro:
    """Mixin providing WCS header fields for equatorial (ICRS) coordinates."""

    # Default central longitude (right ascension), in degrees.
    _crval1 = 180
    _xcoord = 'RA--'
    _ycoord = 'DEC-'
    _radesys = 'ICRS'


class GeoAngleFormatterLocator(AngleFormatterLocator):
    """Tick formatter that shows longitudes as reference angles (-180..180)."""

    def formatter(self, values, spacing):
        # Re-express the tick values as reference angles before delegating
        # to the stock formatter.
        degrees = reference_angle_deg(values.to(u.deg).value)
        return super().formatter(degrees * u.deg, spacing)


class Geo:
    """Mixin providing WCS header fields for Earth-fixed (ITRS) coordinates."""

    _crval1 = 0
    _radesys = 'ITRS'
    _xcoord = 'TLON'
    _ycoord = 'TLAT'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.invert_xaxis()
        # Swap in a formatter that displays longitudes as reference angles.
        longitude = self.coords[0]
        old_fl = longitude._formatter_locator
        longitude._formatter_locator = GeoAngleFormatterLocator(
            values=old_fl.values,
            number=old_fl.number,
            spacing=old_fl.spacing,
            format=old_fl.format,
            format_unit=old_fl.format_unit)


class GalacticAngleFormatterLocator(AngleFormatterLocator):
    """Tick formatter that shows longitudes as wrapped angles (0..360)."""

    def formatter(self, values, spacing):
        # Wrap the tick values into [0, 360) degrees before delegating to
        # the stock formatter.
        degrees = wrapped_angle_deg(values.to(u.deg).value)
        return super().formatter(degrees * u.deg, spacing)


class Galactic:
    """Mixin providing WCS header fields for Galactic coordinates."""

    _crval1 = 0
    _radesys = 'GALACTIC'
    _xcoord = 'GLON'
    _ycoord = 'GLAT'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Swap in a formatter that displays longitudes as wrapped angles.
        longitude = self.coords[0]
        old_fl = longitude._formatter_locator
        longitude._formatter_locator = GalacticAngleFormatterLocator(
            values=old_fl.values,
            number=old_fl.number,
            spacing=old_fl.spacing,
            format=old_fl.format,
            format_unit=old_fl.format_unit)


class Degrees:
    """WCS axes with longitude axis in degrees."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        longitude = self.coords[0]
        longitude.set_format_unit(u.degree)


class Hours:
    """WCS axes with longitude axis in hour angle."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        longitude = self.coords[0]
        longitude.set_format_unit(u.hourangle)


class Globe(AutoScaledWCSAxes):
    """Orthographic ('SIN') projection centered on a given sky position."""

    def __init__(self, *args, center='0d 0d', rotate=None, **kwargs):
        origin = SkyCoord(
            center, representation_type=UnitSphericalRepresentation).icrs
        # Pixel scale chosen so the 180x180 grid spans the visible hemisphere.
        cdelt = 2 / np.pi
        header = {
            'NAXIS': 2,
            'NAXIS1': 180,
            'NAXIS2': 180,
            'CRPIX1': 90.5,
            'CRPIX2': 90.5,
            'CRVAL1': origin.ra.deg,
            'CRVAL2': origin.dec.deg,
            'CDELT1': -cdelt,
            'CDELT2': cdelt,
            'CTYPE1': self._xcoord + '-SIN',
            'CTYPE2': self._ycoord + '-SIN',
            'RADESYS': self._radesys}
        if rotate is not None:
            header['LONPOLE'] = u.Quantity(rotate).to_value(u.deg)
        super().__init__(
            *args, frame_class=EllipticalFrame, header=header, **kwargs)


class Zoom(AutoScaledWCSAxes):
    """Gnomonic ('TAN') projection zoomed in on a given sky position."""

    def __init__(self, *args, center='0d 0d', radius='1 deg', rotate=None,
                 **kwargs):
        origin = SkyCoord(
            center, representation_type=UnitSphericalRepresentation).icrs
        radius_deg = u.Quantity(radius).to(u.deg).value
        # 512x512 grid; half-width of the field of view equals the radius.
        header = {
            'NAXIS': 2,
            'NAXIS1': 512,
            'NAXIS2': 512,
            'CRPIX1': 256.5,
            'CRPIX2': 256.5,
            'CRVAL1': origin.ra.deg,
            'CRVAL2': origin.dec.deg,
            'CDELT1': -radius_deg / 256,
            'CDELT2': radius_deg / 256,
            'CTYPE1': self._xcoord + '-TAN',
            'CTYPE2': self._ycoord + '-TAN',
            'RADESYS': self._radesys}
        if rotate is not None:
            header['LONPOLE'] = u.Quantity(rotate).to_value(u.deg)
        super().__init__(*args, header=header, **kwargs)


class AllSkyAxes(AutoScaledWCSAxes):
    """Base class for a multi-purpose all-sky projection."""

    def __init__(self, *args, center=None, **kwargs):
        # Default to the mixin's preferred central longitude on the equator.
        if center is None:
            center = f"{self._crval1}d 0d"
        origin = SkyCoord(
            center, representation_type=UnitSphericalRepresentation).icrs
        cdelt = 2 * np.sqrt(2) / np.pi
        header = {
            'NAXIS': 2,
            'NAXIS1': 360,
            'NAXIS2': 180,
            'CRPIX1': 180.5,
            'CRPIX2': 90.5,
            'CRVAL1': origin.ra.deg,
            'CRVAL2': origin.dec.deg,
            'CDELT1': -cdelt,
            'CDELT2': cdelt,
            'CTYPE1': self._xcoord + '-' + self._wcsprj,
            'CTYPE2': self._ycoord + '-' + self._wcsprj,
            'RADESYS': self._radesys}
        super().__init__(
            *args, frame_class=EllipticalFrame, header=header, **kwargs)
        # Sensible default graticule: 45 deg in longitude, 30 deg in latitude.
        for coord, spacing in zip(self.coords, (45 * u.deg, 30 * u.deg)):
            coord.set_ticks(spacing=spacing)
            coord.set_ticklabel(exclude_overlapping=True)


class Aitoff(AllSkyAxes):
    """All-sky axes using the Aitoff ('AIT') WCS projection."""

    _wcsprj = 'AIT'


class Mollweide(AllSkyAxes):
    """All-sky axes using the Mollweide ('MOL') WCS projection."""

    _wcsprj = 'MOL'


# Assignments into moddict are assignments into this module's namespace.
moddict = globals()

#
# Create subclasses and register all projections:
# '{astro|geo|galactic} {hours|degrees} {aitoff|globe|mollweide|zoom}'
#
bases1 = (Astro, Geo, Galactic)
bases2 = (Hours, Degrees)
bases3 = (Aitoff, Globe, Mollweide, Zoom)
for bases in product(bases1, bases2, bases3):
    # E.g. bases (Astro, Hours, Mollweide) produce class AstroHoursMollweideAxes
    # registered under the projection name 'astro hours mollweide'.
    class_name = ''.join(cls.__name__ for cls in bases) + 'Axes'
    projection = ' '.join(cls.__name__.lower() for cls in bases)
    new_class = type(class_name, bases, {'name': projection})
    projection_registry.register(new_class)
    moddict[class_name] = new_class
    __all__.append(class_name)

#
# Create some synonyms:
# 'astro' will be short for 'astro hours',
# 'geo' will be short for 'geo degrees'
#
# Pair each coordinate mixin with its preferred longitude unit:
# Astro -> Hours, Geo -> Degrees, Galactic -> Degrees.
bases2 = (Hours, Degrees, Degrees)
for base1, base2 in zip(bases1, bases2):
    for base3 in (Aitoff, Globe, Mollweide, Zoom):
        bases = (base1, base2, base3)
        # Subclass the fully-qualified class under the two-word name, e.g.
        # AstroMollweideAxes with projection name 'astro mollweide'.
        orig_class_name = ''.join(cls.__name__ for cls in bases) + 'Axes'
        orig_class = moddict[orig_class_name]
        class_name = ''.join(cls.__name__ for cls in (base1, base3)) + 'Axes'
        projection = ' '.join(cls.__name__.lower() for cls in (base1, base3))
        new_class = type(class_name, (orig_class,), {'name': projection})
        projection_registry.register(new_class)
        moddict[class_name] = new_class
        __all__.append(class_name)

# Remove loop temporaries from the module namespace and freeze the public API.
del class_name, moddict, projection, projection_registry, new_class
__all__ = tuple(__all__)

ligo/skymap/plot/angle.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
#
# Copyright (C) 2012-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Angle utilities."""
import numpy as np

__all__ = ('reference_angle', 'reference_angle_deg',
           'wrapped_angle', 'wrapped_angle_deg')


def reference_angle(a):
    """Convert an angle to a reference angle between -pi and pi.

    Examples
    --------

    >>> reference_angle(1.5 * np.pi)
    array(-1.57079633)
    """
    # Wrap into [0, 2*pi), then shift the upper half-turn down by a full turn.
    wrapped = np.mod(a, 2 * np.pi)
    return np.where(wrapped <= np.pi, wrapped, wrapped - 2 * np.pi)


def reference_angle_deg(a):
    """Convert an angle to a reference angle between -180 and 180 degrees.

    Examples
    --------

    >>> reference_angle_deg(270.)
    array(-90.)
    """
    # Wrap into [0, 360), then shift the upper half-turn down by a full turn.
    wrapped = np.mod(a, 360)
    return np.where(wrapped <= 180, wrapped, wrapped - 360)


def wrapped_angle(a):
    """Convert an angle to a wrapped angle between 0 and 2*pi.

    Examples
    --------

    >>> wrapped_angle(3 * np.pi)
    3.141592653589793
    """
    full_turn = 2 * np.pi
    return np.mod(a, full_turn)


def wrapped_angle_deg(a):
    """Convert an angle to a wrapped angle between 0 and 360 degrees.

    Examples
    --------

    >>> wrapped_angle_deg(540.)
    180.0
    """
    # Fixed docstring: this function works in degrees (0-360), not radians;
    # the previous text ("between 0 and 2*pi") was copied from wrapped_angle.
    return np.mod(a, 360)

ligo/skymap/plot/backdrop.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
#
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Backdrops for astronomical plots."""

from importlib.resources import files
import json
import warnings

from astropy.io import fits
from astropy.time import Time
from astropy.utils.data import download_file
from astropy.wcs import WCS
from matplotlib.image import imread
import numpy as np
from PIL.Image import DecompressionBombWarning
from reproject import reproject_interp

__all__ = ('bluemarble', 'blackmarble', 'coastlines', 'mellinger',
           'reproject_interp_rgb')


def big_imread(*args, **kwargs):
    """Wrapper for imread() that suppresses warnings when loading very large
    images (usually tiffs). Most of the all-sky images that we use in this
    module are large enough to trigger this warning:

        DecompressionBombWarning: Image size (91125000 pixels) exceeds limit of
        89478485 pixels, could be decompression bomb DOS attack.
    """
    with warnings.catch_warnings():
        # The warning filter is restored as soon as the context exits.
        warnings.simplefilter('ignore', DecompressionBombWarning)
        return imread(*args, **kwargs)


def mellinger():
    """Get the Mellinger Milky Way panorama.

    Retrieve, cache, and return the Mellinger Milky Way panorama. See
    http://www.milkywaysky.com.

    Returns
    -------
    `astropy.io.fits.ImageHDU`
        A FITS WCS image in ICRS coordinates.

    Examples
    --------
    .. plot::
       :context: reset
       :include-source:
       :align: center

        from astropy.visualization import (ImageNormalize,
                                           AsymmetricPercentileInterval)
        from astropy.wcs import WCS
        from matplotlib import pyplot as plt
        from ligo.skymap.plot import mellinger
        from reproject import reproject_interp

        ax = plt.axes(projection='astro hours aitoff')
        backdrop = mellinger()
        backdrop_wcs = WCS(backdrop.header).dropaxis(-1)
        interval = AsymmetricPercentileInterval(45, 98)
        norm = ImageNormalize(backdrop.data, interval)
        backdrop_reprojected = np.asarray([
            reproject_interp((layer, backdrop_wcs), ax.header)[0]
            for layer in norm(backdrop.data)])
        backdrop_reprojected = np.rollaxis(backdrop_reprojected, 0, 3)
        ax.imshow(backdrop_reprojected)

    """
    url = 'http://galaxy.phy.cmich.edu/~axel/mwpan2/mwpan2_RGB_3600.fits'
    # The remote file contains exactly one HDU; unpack and return it.
    hdulist = fits.open(url, cache=True)
    hdu, = hdulist
    return hdu


def bluemarble(t, resolution='low'):
    """Get the "Blue Marble" image.

    Retrieve, cache, and return the NASA/NOAO/NPP "Blue Marble" image showing
    landforms and oceans.

    See https://visibleearth.nasa.gov/view.php?id=74117.

    Parameters
    ----------
    t : `astropy.time.Time`
        Time to embed in the WCS header.
    resolution : {'low', 'high'}
        Specify which version to use: the "low" resolution version (5400x2700
        pixels, the default) or the "high" resolution version (21600x10800
        pixels).

    Returns
    -------
    `astropy.io.fits.ImageHDU`
        A FITS WCS image in ICRS coordinates.

    Examples
    --------
    .. plot::
       :context: reset
       :include-source:
       :align: center

        from matplotlib import pyplot as plt
        from ligo.skymap.plot import bluemarble, reproject_interp_rgb

        obstime = '2017-08-17 12:41:04'
        ax = plt.axes(projection='geo degrees aitoff', obstime=obstime)
        ax.imshow(reproject_interp_rgb(bluemarble(obstime), ax.header))

    """
    # Image dimensions for each published resolution variant.
    shapes = {
        'low': '5400x2700',
        'high': '21600x10800'
    }

    url = ('https://eoimages.gsfc.nasa.gov/images/imagerecords/74000/74117/'
           'world.200408.3x{}.png'.format(shapes[resolution]))
    img = big_imread(download_file(url, cache=True))
    height, width, ndim = img.shape
    # Anchor the plate carree image to the sky via the Greenwich mean
    # sidereal time at t.
    gmst_deg = Time(t).sidereal_time('mean', 'greenwich').deg
    cards = dict(
        NAXIS=3,
        NAXIS1=ndim, NAXIS2=width, NAXIS3=height,
        CRPIX2=width / 2, CRPIX3=height / 2,
        CRVAL2=gmst_deg % 360, CRVAL3=0,
        CDELT2=360 / width,
        CDELT3=-180 / height,
        CTYPE2='RA---CAR',
        CTYPE3='DEC--CAR',
        RADESYSa='ICRS')
    return fits.ImageHDU(img[:, :, :], fits.Header(cards.items()))


def blackmarble(t, resolution='low'):
    """Get the "Black Marble" image.

    Get the NASA/NOAO/NPP image showing city lights, at the sidereal time given
    by t. See https://visibleearth.nasa.gov/view.php?id=79765.

    Parameters
    ----------
    t : `astropy.time.Time`
        Time to embed in the WCS header.
    resolution : {'low', 'mid', 'high'}
        Specify which version to use: the "low" resolution version (3600x1800
        pixels, the default), the "mid" resolution version (13500x6750 pixels),
        or the "high" resolution version (54000x27000 pixels).

    Returns
    -------
    `astropy.io.fits.ImageHDU`
        A FITS WCS image in ICRS coordinates.

    Examples
    --------
    .. plot::
       :context: reset
       :include-source:
       :align: center

        from matplotlib import pyplot as plt
        from ligo.skymap.plot import blackmarble, reproject_interp_rgb

        obstime = '2017-08-17 12:41:04'
        ax = plt.axes(projection='geo degrees aitoff', obstime=obstime)
        ax.imshow(reproject_interp_rgb(blackmarble(obstime), ax.header))

    """
    # Image dimensions for each published resolution variant.
    # FIX: the 'mid' and 'high' entries were previously swapped, so that
    # resolution='mid' fetched the enormous 54000x27000 image and
    # resolution='high' fetched the 13500x6750 one, contradicting the
    # docstring above.
    variants = {
        'low': '3600x1800',
        'mid': '13500x6750',
        'high': '54000x27000'
    }

    url = ('http://eoimages.gsfc.nasa.gov/images/imagerecords/79000/79765/'
           'dnb_land_ocean_ice.2012.{}_geo.tif'.format(variants[resolution]))
    img = big_imread(download_file(url, cache=True))
    height, width, ndim = img.shape
    # Anchor the plate carree image to the sky via the Greenwich mean
    # sidereal time at t.
    gmst_deg = Time(t).sidereal_time('mean', 'greenwich').deg
    header = fits.Header(dict(
        NAXIS=3,
        NAXIS1=ndim, NAXIS2=width, NAXIS3=height,
        CRPIX2=width / 2, CRPIX3=height / 2,
        CRVAL2=gmst_deg % 360, CRVAL3=0,
        CDELT2=360 / width,
        CDELT3=-180 / height,
        CTYPE2='RA---CAR',
        CTYPE3='DEC--CAR',
        RADESYSa='ICRS').items())
    return fits.ImageHDU(img[:, :, :], header)


def reproject_interp_rgb(input_data, *args, **kwargs):
    """Reproject each color channel of an RGB image HDU, returning an
    (height, width, 3) array suitable for `matplotlib.axes.Axes.imshow`.

    Extra arguments are passed to `reproject.reproject_interp`.
    """
    data = input_data.data
    wcs = WCS(input_data.header).celestial
    # Reproject the three channels independently, preserving the input dtype.
    channels = [
        reproject_interp((data[:, :, i], wcs),
                         *args, **kwargs)[0].astype(data.dtype)
        for i in range(3)]
    # Restack so that color is the trailing axis again.
    return np.moveaxis(np.stack(channels), 0, -1)


def coastlines():
    """Load the simplified coastline polylines bundled with this package."""
    resource = files(__package__).joinpath('ne_simplified_coastline.json')
    with resource.open() as f:
        geoms = json.load(f)['geometries']
    # Flatten all geometries into a single list of coordinate tuples.
    result = []
    for geom in geoms:
        result.extend(zip(*geom['coordinates']))
    return result

ligo/skymap/plot/bayes_factor.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
#
# Copyright (C) 2019-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Bullet charts for Bayes factors."""

from matplotlib import pyplot as plt
import numpy as np

__all__ = ('plot_bayes_factor',)


def plot_bayes_factor(logb,
                      values=(1, 3, 5),
                      labels=('', 'strong', 'very strong'),
                      xlim=7, title=None, palette='RdYlBu',
                      var_label="B"):
    """Visualize a Bayes factor as a `bullet graph`_.

    Make a bar chart of a log Bayes factor as compared to a set of subjective
    threshold values. By default, use the thresholds from
    Kass & Raftery (1995).

    .. _`bullet graph`: https://en.wikipedia.org/wiki/Bullet_graph

    Parameters
    ----------
    logb : float
        The natural logarithm of the Bayes factor.
    values : list
        A list of floating point values for human-friendly confidence levels.
    labels : list
        A list of string labels for human-friendly confidence levels.
    xlim : float
        Limits of plot (`-xlim` to `+xlim`).
    title : str
        Title for plot.
    palette : str
        Color palette.
    var_label : str
        The variable symbol used in plotting

    Returns
    -------
    fig : Matplotlib figure
    ax : Matplotlib axes

    Examples
    --------

    .. plot::
       :include-source:

        from ligo.skymap.plot.bayes_factor import plot_bayes_factor
        plot_bayes_factor(6.3, title='BAYESTAR is awesome')

    """
    with plt.style.context('seaborn-v0_8-notebook'):
        fig, ax = plt.subplots(figsize=(6, 1.7), tight_layout=True)
        ax.set_xlim(-xlim, xlim)
        ax.set_ylim(-0.5, 0.5)
        ax.set_yticks([])
        ax.set_title(title)
        # Axis label, e.g. 'ln B', placed to the left of the bar.
        ax.set_ylabel(r'$\ln\,{}$'.format(var_label), rotation=0,
                      rotation_mode='anchor',
                      ha='right', va='center')

        # Add human-friendly labels: thresholds mirrored about zero, with
        # 'evidence against' on the negative side and 'evidence for' on the
        # positive side.
        ticks = (*(-x for x in reversed(values)), 0, *values)
        ticklabels = (
            *(f'{s}\nevidence\nagainst'.strip() for s in reversed(labels)), '',
            *(f'{s}\nevidence\nfor'.strip() for s in labels))
        ax.set_xticks(ticks)
        ax.set_xticklabels(ticklabels)
        plt.setp(ax.get_xticklines(), visible=False)
        # Right-align labels left of center, left-align labels right of it.
        plt.setp(ax.get_xticklabels()[:len(ticks) // 2], ha='right')
        plt.setp(ax.get_xticklabels()[len(ticks) // 2:], ha='left')

        # Plot colored bands for confidence thresholds. The twin x-axis on
        # top shows the signed numeric threshold values ('+0' displayed as
        # '0').
        fmt = plt.FuncFormatter(lambda x, _: f'{x:+g}'.replace('+0', '0'))
        ax2 = ax.twiny()
        ax2.set_xlim(*ax.get_xlim())
        ax2.set_xticks(ticks)
        ax2.xaxis.set_major_formatter(fmt)
        # One full-height band per interval between consecutive thresholds,
        # colored from the chosen palette.
        levels = (-xlim, *ticks, xlim)
        colors = plt.get_cmap(palette)(np.arange(1, len(levels)) / len(levels))
        ax.barh(0, np.diff(levels), 1, levels[:-1],
                linewidth=plt.rcParams['xtick.major.width'],
                color=colors, edgecolor='white')

        # Plot bar for log Bayes factor value
        ax.barh(0, logb, 0.5, color='black',
                linewidth=plt.rcParams['xtick.major.width'],
                edgecolor='white')

        # Remove grid lines and spines from both axes for the bullet-chart
        # look.
        for ax_ in fig.axes:
            ax_.grid(False)
            for spine in ax_.spines.values():
                spine.set_visible(False)

    return fig, ax

ligo/skymap/plot/cmap.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
#
# Copyright (C) 2014-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Register some extra Matplotlib color maps"""

from importlib.resources import files

from matplotlib import colormaps
from matplotlib import colors
import numpy as np

__all__ = ()


for name in ['cylon']:
    # Read in color map RGB data from the CSV file bundled with the package.
    with files(__package__).joinpath(f'{name}.csv').open() as f:
        data = np.loadtxt(f, delimiter=',')

    # Create color map.
    cmap = colors.LinearSegmentedColormap.from_list(name, data)
    # Assign in module (locals() is the module namespace at top level).
    locals().update({name: cmap})
    # Register with Matplotlib.
    colormaps.register(cmap=cmap, force=True)

    # Generate reversed color map by flipping the RGB rows.
    name += '_r'
    data = data[::-1]
    cmap = colors.LinearSegmentedColormap.from_list(name, data)
    # Assign in module.
    locals().update({name: cmap})
    # Register with Matplotlib.
    colormaps.register(cmap=cmap, force=True)

ligo/skymap/plot/cylon.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
#
# Copyright (C) 2014-2018  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""RGB data for the "Cylon red" color map.

A print- and screen-friendly color map designed specifically for plotting
LSC/Virgo/KAGRA sky maps. The color map is constructed in CIE Lab space,
following a linear ramp in lightness (the `l` coordinate) and a cubic spline in
color components (the `a` and `b` coordinates).

This particular color map was selected from 20 random realizations of this
construction."""

if __name__ == '__main__':  # pragma: no cover
    # Script mode: regenerate the cylon.csv color map data. Requires the
    # third-party 'colormath' package, which is not a runtime dependency.
    from colormath.color_conversions import convert_color
    from colormath.color_objects import LabColor, sRGBColor
    from scipy.interpolate import interp1d
    import numpy as np

    def lab_to_rgb(*args):
        """Convert Lab color to sRGB, with components clipped to (0, 1)."""
        Lab = LabColor(*args)
        sRGB = convert_color(Lab, sRGBColor)
        return np.clip(sRGB.get_value_tuple(), 0, 1)

    # Linear ramp in lightness from white (L=100) to black (L=0).
    L_samples = np.linspace(100, 0, 5)

    # Spline knots for the chromatic (a, b) coordinates at the L_samples.
    a_samples = (
        33.34664938,
        98.09940562,
        84.48361516,
        76.62970841,
        21.43276891)

    b_samples = (
        62.73345997,
        2.09003022,
        37.28252236,
        76.22507582,
        16.24862535)

    L = np.linspace(100, 0, 255)
    # NOTE(review): a_samples/b_samples are reversed here while L_samples is
    # already descending — presumably intentional to pair knots with
    # ascending L; confirm against the published color map before changing.
    a = interp1d(L_samples, a_samples[::-1], 'cubic')(L)
    b = interp1d(L_samples, b_samples[::-1], 'cubic')(L)

    # Emit the module docstring as '#' comment lines, then one 'r,g,b' row
    # per lightness step.
    for line in __doc__.splitlines():
        print('#', line)
    for L, a, b in zip(L, a, b):  # names deliberately shadow the arrays
        print(*lab_to_rgb(L, a, b), sep=',')

ligo/skymap/plot/marker.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
#
# Copyright (C) 2016-2019  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Specialized markers."""

from matplotlib.path import Path
import numpy as np

__all__ = ('earth', 'reticle')


# Build the Earth symbol: a unit circle with a horizontal and a vertical
# diameter drawn through it.
earth = Path.unit_circle()
_cross_verts = [[-1, 0], [1, 0], [0, -1], [0, 1]]
_cross_codes = [Path.MOVETO, Path.LINETO, Path.MOVETO, Path.LINETO]
earth = Path(np.concatenate([earth.vertices, _cross_verts]),
             np.concatenate([earth.codes, _cross_codes]))
del _cross_verts, _cross_codes
earth.__doc__ = """
The Earth symbol (circle and cross).

Examples
--------
.. plot::
   :context: reset
   :include-source:
   :align: center

    from matplotlib import pyplot as plt
    from ligo.skymap.plot.marker import earth

    plt.plot(0, 0, markersize=20, markeredgewidth=2,
             markerfacecolor='none', marker=earth)

"""


def reticle(inner=0.5, outer=1.0, angle=0.0, which='lrtb'):
    """Create a reticle or crosshairs marker.

    Parameters
    ----------
    inner : float
        Distance from the origin to the inside of the crosshairs.
    outer : float
        Distance from the origin to the outside of the crosshairs.
    angle : float
        Rotation in degrees; 0 for a '+' orientation and 45 for 'x'.
    which : str
        Which crosshair arms to draw, as a string containing any combination
        of the characters 'l' (left), 'r' (right), 'b' (bottom), and 't'
        (top). The default, 'lrtb', draws all four arms.

    Returns
    -------
    path : `matplotlib.path.Path`
        The new marker path, suitable for passing to Matplotlib functions
        (e.g., `plt.plot(..., marker=reticle())`)

    Examples
    --------
    .. plot::
       :context: reset
       :include-source:
       :align: center

        from matplotlib import pyplot as plt
        from ligo.skymap.plot.marker import reticle

        markers = [reticle(inner=0),
                   reticle(which='lt'),
                   reticle(which='lt', angle=45)]

        fig, ax = plt.subplots(figsize=(6, 2))
        ax.set_xlim(-0.5, 2.5)
        ax.set_ylim(-0.5, 0.5)
        for x, marker in enumerate(markers):
            ax.plot(x, 0, markersize=20, markeredgewidth=2, marker=marker)

    """
    angle = np.deg2rad(angle)
    x = np.cos(angle)
    y = np.sin(angle)
    # 2x2 rotation applied to each arm's unit direction vector.
    rotation = [[x, y], [-y, x]]
    vertdict = {'l': [-1, 0], 'r': [1, 0], 'b': [0, -1], 't': [0, 1]}
    verts = [vertdict[direction] for direction in which]
    codes = [Path.MOVETO, Path.LINETO] * len(verts)
    verts = np.dot(verts, rotation)
    # Each arm is a segment from inner*direction to outer*direction.
    verts = np.swapaxes([inner * verts, outer * verts], 0, 1).reshape(-1, 2)
    return Path(verts, codes)

ligo/skymap/plot/poly.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
#
# Copyright (C) 2012-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Plotting tools for drawing polygons."""
import numpy as np
import healpy as hp

from .angle import reference_angle, wrapped_angle

__all__ = ('subdivide_vertices', 'cut_dateline',
           'cut_prime_meridian', 'make_rect_poly')


def subdivide_vertices(vertices, subdivisions):
    """Linearly subdivide the edges of a closed polygon.

    For each edge from ``vertices[i - 1]`` to ``vertices[i]`` (wrapping
    around), emit ``subdivisions`` points: the starting vertex followed by
    ``subdivisions - 1`` equally spaced interpolated points. The end vertex
    of each edge is the first point of the next edge, so the result has
    ``subdivisions * len(vertices)`` rows.
    """
    # Interpolation weights: start_weight[j] = (s - j)/s, end_weight[j] = j/s.
    steps = np.arange(subdivisions + 1, dtype=float) / subdivisions
    start_weight = steps[:0:-1, np.newaxis]
    end_weight = steps[:-1, np.newaxis]
    # Edge i runs from vertices[i - 1] to vertices[i].
    starts = np.roll(vertices, 1, axis=0)
    # Broadcast to shape (n_edges, subdivisions, n_dims), then flatten.
    blocks = (start_weight * starts[:, np.newaxis, :] +
              end_weight * vertices[:, np.newaxis, :])
    return blocks.reshape(-1, vertices.shape[1])


def cut_dateline(vertices):
    """Cut a polygon across the dateline, possibly splitting it into multiple
    polygons. Vertices consist of (longitude, latitude) pairs where longitude
    is always given in terms of a reference angle (between -π and π).

    This routine is not meant to cover all possible cases; it will only work
    for convex polygons that extend over less than a hemisphere.
    """
    # Shift longitudes so the dateline maps onto the prime meridian,
    # delegate the actual cutting, then shift every output polygon back.
    shifted = vertices.copy()
    shifted[:, 0] += np.pi
    polys = cut_prime_meridian(shifted)
    for poly in polys:
        poly[:, 0] -= np.pi
    return polys


def cut_prime_meridian(vertices):
    """Cut a polygon across the prime meridian, possibly splitting it into
    multiple polygons. Vertices consist of (longitude, latitude) pairs where
    longitude is always given in terms of a wrapped angle (between 0 and 2π).

    This routine is not meant to cover all possible cases; it will only work
    for convex polygons that extend over less than a hemisphere.

    Parameters
    ----------
    vertices : `numpy.ndarray`
        An Nx2 array of (longitude, latitude) pairs in radians.

    Returns
    -------
    list
        A list of Nx2 vertex arrays, one per output polygon.
    """
    # Infrequently used import.
    from shapely import geometry

    # Ensure that the list of vertices does not contain a repeated endpoint.
    if (vertices[0] == vertices[-1]).all():
        vertices = vertices[:-1]

    # Ensure that the longitudes are wrapped from 0 to 2π.
    vertices = np.column_stack((wrapped_angle(vertices[:, 0]), vertices[:, 1]))

    # Test if the segment consisting of points i-1 and i crosses the meridian.
    #
    # If the two longitudes are in [0, 2π), then the shortest arc connecting
    # them crosses the meridian if the difference of the angles is greater
    # than π.
    phis = vertices[:, 0]
    # FIX: np.row_stack was deprecated and removed in NumPy 2.0;
    # np.vstack is the exact equivalent.
    phi0, phi1 = np.sort(np.vstack((np.roll(phis, 1), phis)), axis=0)
    crosses_meridian = (phi1 - phi0 > np.pi)

    # Count the number of times that the polygon crosses the meridian.
    meridian_crossings = np.sum(crosses_meridian)

    if meridian_crossings == 0:
        # There were zero meridian crossings, so we can use the
        # original vertices as is.
        out_vertices = [vertices]
    elif meridian_crossings == 1:
        # There was one meridian crossing, so the polygon encloses the pole.
        # Any meridian-crossing edge has to be extended
        # into a curve following the nearest polar edge of the map.
        i, = np.flatnonzero(crosses_meridian)
        v0 = vertices[i - 1]
        v1 = vertices[i]

        # Find the latitude at which the meridian crossing occurs by
        # linear interpolation.
        delta_lon = abs(reference_angle(v1[0] - v0[0]))
        lat = (abs(reference_angle(v0[0])) / delta_lon * v0[1] +
               abs(reference_angle(v1[0])) / delta_lon * v1[1])

        # FIXME: Use this simple heuristic to decide which pole to enclose.
        sign_lat = np.sign(np.sum(vertices[:, 1]))

        # Find the closer of the left or the right map boundary for
        # each vertex in the line segment.
        lon_0 = 0. if v0[0] < np.pi else 2 * np.pi
        lon_1 = 0. if v1[0] < np.pi else 2 * np.pi

        # Set the output vertices to the polar cap plus the original
        # vertices.
        out_vertices = [
            np.vstack((
                vertices[:i],
                [[lon_0, lat],
                 [lon_0, sign_lat * np.pi / 2],
                 [lon_1, sign_lat * np.pi / 2],
                 [lon_1, lat]],
                vertices[i:]))]
    elif meridian_crossings == 2:
        # Since the polygon is assumed to be convex, if there is an even number
        # of meridian crossings, we know that the polygon does not enclose
        # either pole. Then we can use ordinary Euclidean polygon intersection
        # algorithms.

        out_vertices = []

        # Construct polygon representing map boundaries.
        frame_poly = geometry.Polygon(np.asarray([
            [0., 0.5 * np.pi],
            [0., -0.5 * np.pi],
            [2 * np.pi, -0.5 * np.pi],
            [2 * np.pi, 0.5 * np.pi]]))

        # Intersect with polygon re-wrapped to lie in [-π, π) or [π, 3π).
        for shift in [0, 2 * np.pi]:
            poly = geometry.Polygon(np.column_stack((
                reference_angle(vertices[:, 0]) + shift, vertices[:, 1])))
            intersection = poly.intersection(frame_poly)
            # FIX: test emptiness explicitly; relying on the geometry's
            # truth value is deprecated in Shapely.
            if not intersection.is_empty:
                assert isinstance(intersection, geometry.Polygon)
                assert intersection.is_simple
                # FIX: read coordinates via ``.coords``; the implicit array
                # interface on geometries was removed in Shapely 2.0.
                out_vertices += [np.asarray(intersection.exterior.coords)]
    else:
        # There were more than two intersections. Not implemented!
        raise NotImplementedError('The polygon intersected the map boundaries '
                                  'more than two times, so it is probably not '
                                  'simple and convex.')

    # Done!
    return out_vertices


def make_rect_poly(width, height, theta, phi, subdivisions=10):
    """Return the vertices of a spherical rectangle with half-angles *width*
    and *height* (in degrees), rotated from the north pole to colatitude
    *theta* and longitude *phi*. The result is an array of (longitude,
    latitude) pairs in radians, with each edge subdivided *subdivisions*
    times.
    """
    # Half-extents of the rectangle projected onto the tangent plane.
    half_w = np.sin(np.deg2rad(width))
    half_h = np.sin(np.deg2rad(height))

    # Corner vertices of the flat rectangle, then densify the edges.
    corners = np.asarray([[-half_w, -half_h], [half_w, -half_h],
                          [half_w, half_h], [-half_w, half_h]])
    edges = subdivide_vertices(corners, subdivisions)

    # Lift onto the unit sphere: z follows from |v| = 1.
    z = np.sqrt(1. - np.expand_dims(np.square(edges).sum(1), 1))
    points = np.hstack((edges, z))

    # Rotate the polar cap rectangle down to (theta, phi).
    points = np.dot(points, hp.rotator.euler_matrix_new(phi, theta, 0, Y=True))

    # Back to spherical polar coordinates, returned as (lon, lat).
    colat, lon = hp.vec2ang(points)
    return np.column_stack((wrapped_angle(lon), 0.5 * np.pi - colat))

ligo/skymap/plot/pp.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
#
# Copyright (C) 2012-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Axes subclass for making probability--probability (P--P) plots.

Example
-------
You can create new P--P plot axes by passing the keyword argument
``projection='pp_plot'`` when creating new Matplotlib axes.

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    import numpy as np

    n = 100
    p_values_1 = np.random.uniform(size=n) # One experiment
    p_values_2 = np.random.uniform(size=n) # Another experiment
    p_values_3 = np.random.uniform(size=n) # Yet another experiment

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='pp_plot')
    ax.add_confidence_band(n, alpha=0.95) # Add 95% confidence band
    ax.add_diagonal() # Add diagonal line
    ax.add_lightning(n, 20) # Add some random realizations of n samples
    ax.add_series(p_values_1, p_values_2, p_values_3) # Add our data

Or, you can call the constructor of `PPPlot` directly.

.. plot::
   :context: reset
   :include-source:
   :align: center

    from ligo.skymap.plot import PPPlot
    from matplotlib import pyplot as plt
    import numpy as np

    n = 100

    rect = [0.1, 0.1, 0.8, 0.8] # Where to place axes in figure
    fig = plt.figure(figsize=(5, 5))
    ax = PPPlot(fig, rect)
    fig.add_axes(ax)
    ax.add_confidence_band(n, alpha=0.95)
    ax.add_lightning(n, 20)
    ax.add_diagonal()

"""
import matplotlib
from matplotlib import axes
from matplotlib.projections import projection_registry
import scipy.stats
import numpy as np

__all__ = ('PPPlot',)


class PPPlot(axes.Axes):
    """Construct a probability--probability (P--P) plot."""

    # Projection name: enables ``fig.add_subplot(..., projection='pp_plot')``
    # once this class is registered with Matplotlib.
    name = 'pp_plot'

    def __init__(self, *args, **kwargs):
        # Call parent constructor
        super().__init__(*args, **kwargs)

        # Square axes, limits from 0 to 1
        self.set_aspect(1.0)
        self.set_xlim(0.0, 1.0)
        self.set_ylim(0.0, 1.0)

    @staticmethod
    def _make_series(p_values):
        """Yield alternating x and y arrays, one pair per entry of
        *p_values*, in the positional-argument format expected by ``plot()``.

        1-D entries are sorted into an empirical CDF; 2-D entries supply the
        (x, y) rows directly. Both are padded with endpoints at 0 and 1.
        """
        for ps in p_values:
            if np.ndim(ps) == 1:
                # Empirical CDF: sorted P-values on x, cumulative fraction
                # (1/n .. n/n) on y.
                ps = np.sort(np.atleast_1d(ps))
                n = len(ps)
                xs = np.concatenate(([0.], ps, [1.]))
                ys = np.concatenate(([0.], np.arange(1, n + 1) / n, [1.]))
            elif np.ndim(ps) == 2:
                # Explicit (x, y) series: first row is x, second row is y.
                xs = np.concatenate(([0.], ps[0], [1.]))
                ys = np.concatenate(([0.], ps[1], [1.]))
            else:
                raise ValueError('All series must be 1- or 2-dimensional')
            yield xs
            yield ys

    def add_series(self, *p_values, **kwargs):
        """Add a series of P-values to the plot.

        Parameters
        ----------
        p_values : `numpy.ndarray`
            One or more lists of P-values.

            If an entry in the list is one-dimensional, then it is interpreted
            as an unordered list of P-values. The ranked values will be plotted
            on the horizontal axis, and the cumulative fraction will be plotted
            on the vertical axis.

            If an entry in the list is two-dimensional, then the first subarray
            is plotted on the horizontal axis and the second subarray is
            plotted on the vertical axis.

        drawstyle : {'steps', 'lines', 'default'}
            Plotting style. If ``steps``, then plot steps to represent a
            piecewise constant function. If ``lines``, then connect points with
            straight lines. If ``default`` then use steps if there are more
            than 2 pixels per data point, or else lines.

        Other parameters
        ----------------
        kwargs :
            optional extra arguments to `matplotlib.axes.Axes.plot`

        """
        # Construct sequence of x, y pairs to pass to plot()
        args = list(self._make_series(p_values))
        # Shortest series length drives the steps-vs-lines heuristic below.
        min_n = min(len(ps) for ps in p_values)

        # Make copy of kwargs to pass to plot()
        kwargs = dict(kwargs)
        ds = kwargs.pop('drawstyle', 'default')
        # Draw steps only when there are at least ~2 pixels per data point
        # (or steps were explicitly requested); otherwise connect with lines.
        if (ds == 'default' and 2 * min_n > self.bbox.width) or ds == 'lines':
            kwargs['drawstyle'] = 'default'
        else:
            kwargs['drawstyle'] = 'steps-post'

        return self.plot(*args, **kwargs)

    def add_worst(self, *p_values):
        """Mark the point at which the deviation is largest.

        Parameters
        ----------
        p_values : `numpy.ndarray`
            Same as in `add_series`.

        """
        series = list(self._make_series(p_values))
        # _make_series yields xs, ys alternately; re-pair them here.
        for xs, ys in zip(series[0::2], series[1::2]):
            # Index of the largest vertical deviation from the diagonal.
            i = np.argmax(np.abs(ys - xs))
            x = xs[i]
            y = ys[i]
            if y == x:
                # Series lies exactly on the diagonal: nothing to mark.
                continue
            # Dashed guide lines from both axes to the worst point.
            self.plot([x, x, 0], [0, y, y], '--', color='black', linewidth=0.5)
            if y < x:
                # Horizontal segment spanning the gap; label the deviation.
                self.plot([x, y], [y, y], '-', color='black', linewidth=1)
                self.text(
                    x, y, ' {0:.02f} '.format(np.around(x - y, 2)),
                    ha='left', va='top')
            else:
                # Vertical segment spanning the gap; label the deviation.
                self.plot([x, x], [x, y], '-', color='black', linewidth=1)
                self.text(
                    x, y, ' {0:.02f} '.format(np.around(y - x, 2)),
                    ha='right', va='bottom')

    def add_diagonal(self, *args, **kwargs):
        """Add a diagonal line to the plot, running from (0, 0) to (1, 1).

        Other parameters
        ----------------
        kwargs :
            optional extra arguments to `matplotlib.axes.Axes.plot`

        """
        # Make copy of kwargs to pass to plot()
        kwargs = dict(kwargs)
        kwargs.setdefault('color', 'black')
        kwargs.setdefault('linestyle', 'dashed')
        kwargs.setdefault('linewidth', 0.5)

        # Plot diagonal line
        return self.plot([0, 1], [0, 1], *args, **kwargs)

    def add_lightning(self, nsamples, ntrials, **kwargs):
        """Add P-values drawn from a random uniform distribution, as a visual
        representation of the acceptable scatter about the diagonal.

        Parameters
        ----------
        nsamples : int
            Number of P-values in each trial
        ntrials : int
            Number of line series to draw.

        Other parameters
        ----------------
        kwargs :
            optional extra arguments to `matplotlib.axes.Axes.plot`

        """
        # Draw random samples
        args = np.random.uniform(size=(ntrials, nsamples))

        # Make copy of kwargs to pass to plot()
        kwargs = dict(kwargs)
        kwargs.setdefault('color', 'black')
        kwargs.setdefault('alpha', 0.5)
        kwargs.setdefault('linewidth', 0.25)

        # Plot series
        return self.add_series(*args, **kwargs)

    def add_confidence_band(
            self, nsamples, alpha=0.95, annotate=True, **kwargs):
        """Add a target confidence band.

        Parameters
        ----------
        nsamples : int
            Number of P-values
        alpha : float, default: 0.95
            Confidence level
        annotate : bool, optional, default: True
            If True, then label the confidence band.

        Other parameters
        ----------------
        **kwargs :
            optional extra arguments to `matplotlib.axes.Axes.fill_betweenx`

        """
        n = nsamples
        k = np.arange(0, n + 1)
        p = k / n
        # Beta-distribution interval for each cumulative count k out of n.
        # NOTE(review): the (k + 1, n - k + 1) parameters correspond to a
        # uniform-prior binomial posterior — confirm this is the intended
        # convention for the band.
        ci_lo, ci_hi = scipy.stats.beta.interval(alpha, k + 1, n - k + 1)

        # Make copy of kwargs to pass to fill_betweenx()
        kwargs = dict(kwargs)
        kwargs.setdefault('color', 'lightgray')
        kwargs.setdefault('edgecolor', 'gray')
        kwargs.setdefault('linewidth', 0.5)
        fontsize = kwargs.pop('fontsize', 'x-small')

        if annotate:
            # Escape '%' when TeX renders the label text.
            percent_sign = r'\%' if matplotlib.rcParams['text.usetex'] else '%'
            label = 'target {0:g}{1:s}\nconfidence band'.format(
                100 * alpha, percent_sign)

            self.annotate(
                label,
                xy=(1, 1),
                xytext=(0, 0),
                xycoords='axes fraction',
                textcoords='offset points',
                annotation_clip=False,
                horizontalalignment='right',
                verticalalignment='bottom',
                fontsize=fontsize,
                arrowprops=dict(
                    arrowstyle="->",
                    shrinkA=0, shrinkB=2, linewidth=0.5,
                    connectionstyle="angle,angleA=0,angleB=45,rad=0"))

        return self.fill_betweenx(p, ci_lo, ci_hi, **kwargs)

    @classmethod
    def _as_mpl_axes(cls):
        """Support placement in figure using the `projection` keyword argument.

        See http://matplotlib.org/devel/add_new_projection.html.
        """
        return cls, {}


# Register the projection so ``projection='pp_plot'`` works in Matplotlib.
projection_registry.register(PPPlot)

ligo/skymap/plot/util.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
#
# Copyright (C) 2012-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Miscellaneous plotting utilities."""
import matplotlib
from matplotlib import text
from matplotlib import ticker
from matplotlib import patheffects

__all__ = ('colorbar', 'outline_text')


def colorbar(*args):
    """Create a horizontal colorbar with scientific-notation tick labels,
    folding the order-of-magnitude offset text into the last tick label.

    Parameters
    ----------
    *args :
        Positional arguments passed through to `matplotlib.pyplot.colorbar`.

    Returns
    -------
    `matplotlib.colorbar.Colorbar`

    """
    from matplotlib import pyplot as plt

    usetex = matplotlib.rcParams['text.usetex']
    locator = ticker.AutoLocator()
    # Use mathtext for the offset only when TeX is not already rendering.
    formatter = ticker.ScalarFormatter(useMathText=not usetex)
    formatter.set_scientific(True)
    # NOTE(review): ``set_powerlimits`` expects a pair of *exponents*
    # (e.g. ``(-1, 2)``), not thresholds such as ``(1e-1, 100)`` — confirm
    # whether these float values are intentional.
    formatter.set_powerlimits((1e-1, 100))

    # Plot colorbar
    cb = plt.colorbar(*args,
                      orientation='horizontal', shrink=0.4,
                      ticks=locator, format=formatter)

    if cb.orientation == 'vertical':
        axis = cb.ax.yaxis
    else:
        axis = cb.ax.xaxis

    # Move order of magnitude text into last label.
    ticklabels = [label.get_text() for label in axis.get_ticklabels()]
    # Avoid putting two '$' next to each other if we are in tex mode.
    if usetex:
        fmt = '{{{0}}}{{{1}}}'
    else:
        fmt = '{0}{1}'
    ticklabels[-1] = fmt.format(ticklabels[-1], formatter.get_offset())
    axis.set_ticklabels(ticklabels)
    last_ticklabel = axis.get_ticklabels()[-1]
    last_ticklabel.set_horizontalalignment('left')

    # Draw edges in colorbar bands to correct thin white bands that
    # appear in buggy PDF viewers. See:
    # https://github.com/matplotlib/matplotlib/pull/1301
    cb.solids.set_edgecolor("face")

    # Done.
    return cb


def outline_text(ax):
    """Give every text artist in *ax* a white stroke so that labels stay
    legible against any background.
    """
    stroke = patheffects.withStroke(linewidth=2, foreground='w')
    for label in ax.findobj(text.Text):
        label.set_path_effects([stroke])

ligo/skymap/postprocess/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

# Aggregate public API of the submodules; extended below as each is imported.
__all__ = ()

# Import all symbols from all submodules of this module.
# ``exec`` is used because the submodule names are only known at runtime and
# ``from . import <name>`` / ``from .<name> import *`` cannot take a variable;
# the names come from scanning this package's own directory, so the executed
# code is trusted.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    del module

# Clean up
del os, pkgutil

ligo/skymap/postprocess/contour.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import astropy_healpix as ah
from astropy import units as u
import healpy as hp
import numpy as np

__all__ = ('contour', 'simplify')


def _norm_squared(vertices):
    return np.sum(np.square(vertices), -1)


def _adjacent_triangle_area_squared(vertices):
    """Squared area of the triangle that each vertex forms with its two
    neighbors (wrapping around the ring), via the cross-product formula.
    """
    to_next = np.roll(vertices, -1, axis=0) - vertices
    to_prev = np.roll(vertices, +1, axis=0) - vertices
    return 0.25 * _norm_squared(np.cross(to_next, to_prev))


def _vec2radec(vertices, degrees=False):
    """Convert Cartesian unit vectors to (RA, Dec) pairs, in radians by
    default or in degrees if *degrees* is True.
    """
    colat, lon = hp.vec2ang(np.asarray(vertices))
    radec = np.column_stack((lon % (2 * np.pi), 0.5 * np.pi - colat))
    return np.rad2deg(radec) if degrees else radec


def simplify(vertices, min_area):
    """Simplify a polygon on the unit sphere.

    This is a naive, slow implementation of Visvalingam's algorithm (see
    http://bost.ocks.org/mike/simplify/), adapted for for linear rings on a
    sphere.

    Parameters
    ----------
    vertices : `np.ndarray`
        An Nx3 array of Cartesian vertex coordinates. Each vertex should be a
        unit vector.

    min_area : float
        The minimum area of triangles formed by adjacent triplets of vertices.

    Returns
    -------
    vertices : `np.ndarray`

    """
    # Compare squared areas throughout to avoid square roots.
    area_squared = _adjacent_triangle_area_squared(vertices)
    min_area_squared = np.square(min_area)

    while True:
        # Repeatedly remove the vertex whose adjacent triangle is smallest,
        # until every remaining triangle exceeds the threshold.
        i_min_area = np.argmin(area_squared)
        if area_squared[i_min_area] > min_area_squared:
            break

        vertices = np.delete(vertices, i_min_area, axis=0)
        area_squared = np.delete(area_squared, i_min_area)
        # Recompute areas after the deletion; the elementwise maximum keeps
        # each vertex's "effective area" monotonically nondecreasing, as
        # required by Visvalingam's algorithm.
        new_area_squared = _adjacent_triangle_area_squared(vertices)
        area_squared = np.maximum(area_squared, new_area_squared)

    return vertices


# A synonym for ``simplify`` to avoid aliasing by the keyword argument of the
# same name below: inside ``contour``, the parameter ``simplify`` shadows the
# function, so the function is reached through this alias.
_simplify = simplify


def contour(m, levels, nest=False, degrees=False, simplify=True):
    """Calculate contours from a HEALPix dataset.

    Parameters
    ----------
    m : `numpy.ndarray`
        The HEALPix dataset.
    levels : list
        The list of contour values.
    nest : bool, default=False
        Indicates whether the input sky map is in nested rather than
        ring-indexed HEALPix coordinates (default: ring).
    degrees : bool, default=False
        Whether the contours are in degrees instead of radians.
    simplify : bool, default=True
        Whether to simplify the paths.

    Returns
    -------
    list
        A list with the same length as `levels`.
        Each item is a list of disjoint polygons, of which each item is a
        list of points, of which each is a list consisting of the right
        ascension and declination.

    Examples
    --------
    A very simply example sky map...

    >>> nside = 32
    >>> npix = ah.nside_to_npix(nside)
    >>> ra, dec = hp.pix2ang(nside, np.arange(npix), lonlat=True)
    >>> m = dec
    >>> contour(m, [10, 20, 30], degrees=True)
    [[[[..., ...], ...], ...], ...]

    """
    # Infrequently used import
    import networkx as nx

    # Determine HEALPix resolution.
    npix = len(m)
    nside = ah.npix_to_nside(npix)
    # Threshold for path simplification, a fraction of the pixel area.
    min_area = 0.4 * ah.nside_to_pixel_area(nside).to_value(u.sr)

    # One row of 8 neighbor indices per pixel; entries of -1 mark missing
    # neighbors (filtered out below).
    neighbors = hp.get_all_neighbours(nside, np.arange(npix), nest=nest).T

    # Loop over the requested contours.
    paths = []
    for level in levels:

        # Find credible region.
        indicator = (m >= level)

        # Find all faces that lie on the boundary.
        # This speeds up the doubly nested ``for`` loop below by allowing us to
        # skip the vast majority of faces that are on the interior or the
        # exterior of the contour.
        tovisit = np.flatnonzero(
            np.any(indicator.reshape(-1, 1) !=
                   indicator[neighbors[:, ::2]], axis=1))

        # Construct a graph of the edges of the contour.
        graph = nx.Graph()
        face_pairs = set()
        for ipix1 in tovisit:
            neighborhood = neighbors[ipix1]
            # Rotating the 8-neighbor list by two entries per iteration makes
            # index 4 visit each of the four edge-sharing neighbors in turn
            # (presumably matching healpy's neighbor ordering — verify).
            for _ in range(4):
                neighborhood = np.roll(neighborhood, 2)
                ipix2 = neighborhood[4]

                # Skip this pair of faces if we have already examined it.
                new_face_pair = frozenset((ipix1, ipix2))
                if new_face_pair in face_pairs:
                    continue
                face_pairs.add(new_face_pair)

                # Determine if this pair of faces are on a boundary of the
                # credible level.
                if indicator[ipix1] == indicator[ipix2]:
                    continue

                # Add the common edge of this pair of faces.
                # Label each vertex with the set of faces that they share.
                graph.add_edge(
                    frozenset((ipix1, *neighborhood[2:5])),
                    frozenset((ipix1, *neighborhood[4:7])))
        graph = nx.freeze(graph)

        # Find contours by detecting cycles in the graph.
        cycles = nx.cycle_basis(graph)

        # Construct the coordinates of the vertices by averaging the
        # coordinates of the connected faces.
        cycles = [[
            np.sum(hp.pix2vec(nside, [i for i in v if i != -1], nest=nest), 1)
            for v in cycle] for cycle in cycles]

        # Simplify paths if requested.
        if simplify:
            cycles = [_simplify(cycle, min_area) for cycle in cycles]
            # Drop degenerate cycles that no longer enclose any area.
            cycles = [cycle for cycle in cycles if len(cycle) > 2]

        # Convert to angles.
        cycles = [
            _vec2radec(cycle, degrees=degrees).tolist() for cycle in cycles]

        # Add to output paths.
        # Close each path by repeating its first point at the end.
        paths.append([cycle + [cycle[0]] for cycle in cycles])

    return paths

ligo/skymap/postprocess/cosmology.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
#
# Copyright (C) 2013-2024  Leo Singer, Rainer Corley
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Cosmology-related utilities.

All functions in this module use the Planck15 cosmological parameters.
"""

import numpy as np
from astropy.cosmology import Planck15 as cosmo, z_at_value
import astropy.units as u


def dDL_dz_for_z(z):
    r"""Derivative of luminosity distance with respect to redshift.

    Since :math:`D_L = (1 + z) D_M`, this is computed as
    :math:`\mathrm{d}D_L/\mathrm{d}z
    = (1 + z)\,\mathrm{d}D_M/\mathrm{d}z + D_M`,
    with the curvature factor below converting the comoving-distance
    derivative into the transverse comoving-distance derivative.
    """
    Ok0 = cosmo.Ok0
    DH = cosmo.hubble_distance
    DC_by_DH = (cosmo.comoving_distance(z) / DH).value
    # dD_M/dD_C, whose form depends on the sign of the curvature density.
    if Ok0 == 0.0:
        ret = 1.0
    elif Ok0 > 0.0:
        ret = np.cosh(np.sqrt(Ok0) * DC_by_DH)
    else:  # Ok0 < 0.0 or Ok0 is nan
        ret = np.cos(np.sqrt(-Ok0) * DC_by_DH)
    # (1 + z) dD_M/dz, using dD_C/dz = D_H / E(z).
    ret *= (1 + z) * DH * cosmo.inv_efunc(z)
    # Add the D_M term from differentiating the (1 + z) prefactor.
    ret += cosmo.comoving_transverse_distance(z)
    return ret


def dVC_dVL_for_z(z):
    r"""Ratio, :math:`\mathrm{d}V_C / \mathrm{d}V_L`, between the comoving
    volume element and a naively Euclidean volume element in luminosity
    distance space; given as a function of redshift.

    Given the differential comoving volume per unit redshift,
    :math:`\mathrm{d}V_C / \mathrm{d}z`, and the derivative of luminosity
    distance in terms of redshift, :math:`\mathrm{d}D_L / \mathrm{d}z`, this is
    expressed as:

    .. math::

       \frac{\mathrm{d}V_C}{\mathrm{d}V_L} =
       \frac{\mathrm{d}V_C}{\mathrm{d}z}
       \left(
       {D_L}^2 \frac{\mathrm{d}D_L}{\mathrm{d}z}
       \right)^{-1}.
    """
    Ok0 = cosmo.Ok0
    DH = cosmo.hubble_distance
    DM_by_DH = (cosmo.comoving_transverse_distance(z) / DH).value
    DC_by_DH = (cosmo.comoving_distance(z) / DH).value
    zplus1 = z + 1.0
    # dD_M/dD_C, whose form depends on the sign of the curvature density.
    if Ok0 == 0.0:
        ret = 1.0
    elif Ok0 > 0.0:
        ret = np.cosh(np.sqrt(Ok0) * DC_by_DH)
    else:  # Ok0 < 0.0 or Ok0 is nan
        ret = np.cos(np.sqrt(-Ok0) * DC_by_DH)
    # Accumulate the dimensionless combination
    # (1 + z)^2 * [(1 + z) dD_M/dD_C + (D_M/D_H) E(z)]; its reciprocal is
    # the Jacobian ratio from the docstring, with the dimensional factors
    # of D_H and D_M cancelling between numerator and denominator.
    ret *= zplus1
    ret += DM_by_DH * cosmo.efunc(z)
    ret *= np.square(zplus1)
    return 1.0 / ret


@np.vectorize
def z_for_DL(DL):
    """Redshift as a function of luminosity distance in Mpc."""
    # Invert the cosmology's distance--redshift relation numerically.
    redshift = z_at_value(cosmo.luminosity_distance, DL * u.Mpc)
    return redshift.to_value(u.dimensionless_unscaled)


def dVC_dVL_for_DL(DL):
    """Same as :meth:`dVC_dVL_for_z`, but as a function of luminosity
    distance.
    """
    z = z_for_DL(DL)
    return dVC_dVL_for_z(z)

ligo/skymap/postprocess/crossmatch.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Catalog cross matching for HEALPix sky maps."""
from collections import namedtuple

import astropy_healpix as ah
from astropy.coordinates import ICRS, SkyCoord, SphericalRepresentation
from astropy import units as u
import healpy as hp
import numpy as np

from .. import distance
from .. import moc

from .cosmology import dVC_dVL_for_DL

__all__ = ('crossmatch', 'CrossmatchResult')


def flood_fill(nside, ipix, m, nest=False):
    """Stack-based flood fill algorithm in HEALPix coordinates.

    Based on <http://en.wikipedia.org/w/index.php?title=Flood_fill&oldid=566525693#Alternative_implementations>.
    """  # noqa: E501
    # Seed the stack with the starting pixel.
    pending = [ipix]
    while pending:
        pix = pending.pop()
        # Skip pixels that are already filled.
        if not m[pix]:
            continue
        # Fill this pixel.
        m[pix] = False
        # Each pixel has up to 8 neighbors; missing neighbors are
        # reported as -1 and must be dropped before recursing.
        neighbors = hp.get_all_neighbours(nside, pix, nest=nest)
        pending.extend(neighbors[neighbors != -1])


def count_modes(m, nest=False):
    """Count the number of modes in a binary HEALPix image by repeatedly
    applying the flood-fill algorithm.

    WARNING: The input array is clobbered in the process.
    """
    nside = ah.npix_to_nside(len(m))
    nmodes = 0
    # Each flood fill erases exactly one connected component; stop once
    # the map is empty.
    while True:
        nonzeroipix = np.flatnonzero(m)
        if not len(nonzeroipix):
            break
        flood_fill(nside, nonzeroipix[0], m, nest=nest)
        nmodes += 1
    return nmodes


def count_modes_moc(uniq, i):
    # Build a binary multi-order map in which the first i + 1 cells are
    # set, rasterize it to a flat HEALPix map, and count its components.
    mask = np.arange(len(uniq)) <= i
    sky_map = np.rec.fromarrays((uniq, mask), names=('UNIQ', 'MASK'))
    rasterized = moc.rasterize(sky_map)['MASK']
    return count_modes(rasterized, nest=True)


def cos_angle_distance(theta0, phi0, theta1, phi1):
    """Cosine of angular separation in radians between two points on the
    unit sphere.
    """
    # Spherical law of cosines, with theta the colatitude and phi the
    # azimuth; clip to guard against round-off outside [-1, 1].
    cosine = (np.sin(theta0) * np.sin(theta1) * np.cos(phi1 - phi0)
              + np.cos(theta0) * np.cos(theta1))
    return np.clip(cosine, -1, 1)


def angle_distance(theta0, phi0, theta1, phi1):
    """Angular separation in radians between two points on the unit sphere."""
    cosine = cos_angle_distance(theta0, phi0, theta1, phi1)
    return np.arccos(cosine)


# Named tuple holding the return value of :func:`crossmatch` below.
CrossmatchResult = namedtuple(
    'CrossmatchResult',
    'searched_area searched_prob offset searched_modes contour_areas '
    'area_probs contour_modes searched_prob_dist contour_dists '
    'searched_vol searched_prob_vol contour_vols probdensity probdensity_vol')
"""Cross match result as returned by
:func:`~ligo.skymap.postprocess.crossmatch.crossmatch`.

Notes
-----
 - All probabilities returned are between 0 and 1.
 - All angles returned are in degrees.
 - All areas returned are in square degrees.
 - All distances are luminosity distances in units of Mpc.
 - All volumes are in units of Mpc³. If :func:`.crossmatch` was run with
   ``cosmology=False``, then all volumes are Euclidean volumes in luminosity
   distance. If :func:`.crossmatch` was run with ``cosmology=True``, then all
   volumes are comoving volumes.

"""
# Shared docstring fragments, appended to the per-field docstrings below.
_same_length_as_coordinates = ''' \
Same length as the `coordinates` argument passed to \
:func:`~ligo.skymap.postprocess.crossmatch.crossmatch`.'''
_same_length_as_contours = ''' \
of the probabilities specified by the `contour` argument passed to \
:func:`~ligo.skymap.postprocess.crossmatch.crossmatch`.'''
_same_length_as_areas = ''' \
of the areas specified by the `areas` argument passed to
:func:`~ligo.skymap.postprocess.crossmatch.crossmatch`.'''
# Attach a docstring to each named tuple field so that Sphinx documents them.
CrossmatchResult.searched_area.__doc__ = '''\
Area within the 2D credible region containing each target \
position.''' + _same_length_as_coordinates
CrossmatchResult.searched_prob.__doc__ = '''\
Probability within the 2D credible region containing each target \
position.''' + _same_length_as_coordinates
CrossmatchResult.offset.__doc__ = '''\
Angles on the sky between the target positions and the maximum a posteriori \
position.''' + _same_length_as_coordinates
CrossmatchResult.searched_modes.__doc__ = '''\
Number of disconnected regions within the 2D credible regions \
containing each target position.''' + _same_length_as_coordinates
CrossmatchResult.contour_areas.__doc__ = '''\
Area within the 2D credible regions''' + _same_length_as_contours
CrossmatchResult.area_probs.__doc__ = '''\
Probability within the 2D credible regions''' + _same_length_as_areas
CrossmatchResult.contour_modes.__doc__ = '''\
Number of disconnected regions within the 2D credible \
regions''' + _same_length_as_contours
CrossmatchResult.searched_prob_dist.__doc__ = '''\
Cumulative CDF of distance, marginalized over sky position, at the distance \
of each of the targets.''' + _same_length_as_coordinates
CrossmatchResult.contour_dists.__doc__ = '''\
Distance credible interval, marginalized over sky \
position,''' + _same_length_as_coordinates
CrossmatchResult.searched_vol.__doc__ = '''\
Volume within the 3D credible region containing each target \
position.''' + _same_length_as_coordinates
CrossmatchResult.searched_prob_vol.__doc__ = '''\
Probability within the 3D credible region containing each target \
position.''' + _same_length_as_coordinates
CrossmatchResult.contour_vols.__doc__ = '''\
Volume within the 3D credible regions''' + _same_length_as_contours
CrossmatchResult.probdensity.__doc__ = '''\
2D probability density per steradian at the positions of each of the \
targets.''' + _same_length_as_coordinates
CrossmatchResult.probdensity_vol.__doc__ = '''\
3D probability density per cubic megaparsec at the positions of each of the \
targets.''' + _same_length_as_coordinates


def crossmatch(sky_map, coordinates=None,
               contours=(), areas=(), modes=False, cosmology=False):
    """Cross match a sky map with a catalog of points.

    Given a sky map and the true right ascension and declination (in radians),
    find the smallest area in deg^2 that would have to be searched to find the
    source, the smallest posterior mass, and the angular offset in degrees from
    the true location to the maximum (mode) of the posterior. Optionally, also
    compute the areas of and numbers of modes within the smallest contours
    containing a given total probability.

    Parameters
    ----------
    sky_map : :class:`astropy.table.Table`
        A multiresolution sky map, as returned by
        :func:`ligo.skymap.io.fits.read_sky_map` called with the keyword
        argument ``moc=True``.

    coordinates : :class:`astropy.coordinates.SkyCoord`, optional
        The catalog of target positions to match against.

    contours : :class:`tuple`, optional
        Credible levels between 0 and 1. If this argument is present, then
        calculate the areas and volumes of the 2D and 3D credible regions that
        contain these probabilities. For example, for ``contours=(0.5, 0.9)``,
        then areas and volumes of the 50% and 90% credible regions.

    areas : :class:`tuple`, optional
        Credible areas in square degrees. If this argument is present, then
        calculate the probability contained in the 2D credible levels that have
        these areas. For example, for ``areas=(20, 100)``, then compute the
        probability within the smallest credible levels of 20 deg² and 100
        deg², respectively.

    modes : :class:`bool`, optional
        If True, then enable calculation of the number of distinct modes or
        islands of probability. Note that this option may be computationally
        expensive.

    cosmology : :class:`bool`, optional
        If True, then search space by descending probability density per unit
        comoving volume. If False, then search space by descending probability
        per luminosity distance cubed.

    Returns
    -------
    result : :class:`~ligo.skymap.postprocess.crossmatch.CrossmatchResult`

    Notes
    -----
    This function is also be used for injection finding; see
    :doc:`/tool/ligo_skymap_stats`.

    Examples
    --------
    First, some imports:

    >>> from astroquery.vizier import VizierClass
    >>> from astropy.coordinates import SkyCoord
    >>> from ligo.skymap.io import read_sky_map
    >>> from ligo.skymap.postprocess import crossmatch

    Next, retrieve the GLADE catalog using Astroquery and get the coordinates
    of all its entries:

    >>> vizier = VizierClass(
    ...     row_limit=-1,
    ...     columns=['recno', 'GWGC', '_RAJ2000', '_DEJ2000', 'Dist'])
    >>> cat, = vizier.get_catalogs('VII/281/glade2')
    >>> cat.sort('recno')  # sort catalog so that doctest output is stable
    >>> del cat['recno']
    >>> coordinates = SkyCoord(cat['_RAJ2000'], cat['_DEJ2000'], cat['Dist'])

    Load the multiresolution sky map for S190814bv:

    >>> url = 'https://gracedb.ligo.org/api/superevents/S190814bv/files/bayestar.multiorder.fits'
    >>> skymap = read_sky_map(url, moc=True)

    Perform the cross match:

    >>> result = crossmatch(skymap, coordinates)

    Using the cross match results, we can list the galaxies within the 90%
    credible volume:

    >>> print(cat[result.searched_prob_vol < 0.9])
         _RAJ2000          _DEJ2000        GWGC            Dist
           deg               deg                           Mpc
    ----------------- ----------------- ---------- --------------------
      9.3396700000000 -19.9342460000000    NGC0171    57.56212553960000
     20.2009090000000 -31.1146050000000        ---   137.16022925600001
      8.9144680000000 -20.1252980000000 ESO540-003    49.07809291930000
     10.6762720000000 -21.7740820000000        ---   276.46938505499998
     13.5855170000000 -23.5523850000000        ---   138.44550704800000
     20.6362970000000 -29.9825150000000        ---   160.23313164900000
     13.1923880000000 -22.9750180000000        ---   236.96795954500001
     11.7813630000000 -24.3706470000000        ---   244.25031189699999
     19.1711120000000 -31.4339490000000        ---   152.13614001400001
     13.6367060000000 -23.4948790000000        ---   141.25162979500001
                  ...               ...        ...                  ...
     11.3517000000000 -25.8597000000000        ---   335.73800000000000
     11.2074000000000 -25.7149000000000        ---   309.02999999999997
     11.1875000000000 -25.7504000000000        ---   295.12099999999998
     10.8609000000000 -25.6904000000000        ---   291.07200000000000
     10.6939000000000 -25.6778300000000        ---   323.59399999999999
     15.4935000000000 -26.0305000000000        ---   304.78899999999999
     15.2794000000000 -27.0411000000000        ---   320.62700000000001
     14.8324000000000 -27.0460000000000        ---   320.62700000000001
     14.5341000000000 -26.0949000000000        ---   307.61000000000001
     23.1281000000000 -31.1109200000000        ---   320.62700000000001
    Length = 1479 rows

    """  # noqa: E501, W291
    # Astropy coordinates that are constructed without distance have
    # a distance field that is unity (dimensionless).
    if coordinates is None:
        true_ra = true_dec = true_dist = None
    else:
        # Ensure that coordinates are in proper frame and representation
        coordinates = SkyCoord(coordinates,
                               representation_type=SphericalRepresentation,
                               frame=ICRS)
        true_ra = coordinates.ra.rad
        true_dec = coordinates.dec.rad
        if np.any(coordinates.distance != 1):
            true_dist = coordinates.distance.to_value(u.Mpc)
        else:
            true_dist = None

    contours = np.asarray(contours)

    # Sort the pixels by descending posterior probability.
    sky_map = np.flipud(np.sort(sky_map, order='PROBDENSITY'))

    # Find the pixel that contains the injection.
    order, ipix = moc.uniq2nest(sky_map['UNIQ'])
    max_order = np.max(order)
    max_nside = ah.level_to_nside(max_order)
    max_ipix = ipix << np.int64(2 * (max_order - order))
    if true_ra is not None:
        true_theta = 0.5 * np.pi - true_dec
        true_phi = true_ra
        true_pix = hp.ang2pix(max_nside, true_theta, true_phi, nest=True)
        i = np.argsort(max_ipix)
        true_idx = i[np.digitize(true_pix, max_ipix[i]) - 1]

    # Find the angular offset between the mode and true locations.
    mode_theta, mode_phi = hp.pix2ang(
        ah.level_to_nside(order[0]), ipix[0], nest=True)
    if true_ra is None:
        offset = np.nan
    else:
        offset = np.rad2deg(
            angle_distance(true_theta, true_phi, mode_theta, mode_phi))

    # Calculate the cumulative area in deg2 and the cumulative probability.
    dA = moc.uniq2pixarea(sky_map['UNIQ'])
    dP = sky_map['PROBDENSITY'] * dA
    prob = np.cumsum(dP)
    area = np.cumsum(dA) * np.square(180 / np.pi)

    if true_ra is None:
        searched_area = searched_prob = probdensity = np.nan
    else:
        # Find the smallest area that would have to be searched to find
        # the true location.
        searched_area = area[true_idx]

        # Find the smallest posterior mass that would have to be searched to
        # find the true location.
        searched_prob = prob[true_idx]

        # Find the probability density.
        probdensity = sky_map['PROBDENSITY'][true_idx]

    # Find the contours of the given credible levels.
    contour_idxs = np.digitize(contours, prob) - 1

    # For each of the given confidence levels, compute the area of the
    # smallest region containing that probability.
    contour_areas = np.interp(
        contours, prob, area, left=0, right=4*180**2/np.pi).tolist()

    # For each listed area, find the probability contained within the
    # smallest credible region of that area.
    area_probs = np.interp(areas, area, prob, left=0, right=1).tolist()

    if modes:
        if true_ra is None:
            searched_modes = np.nan
        else:
            # Count up the number of modes in each of the given contours.
            searched_modes = count_modes_moc(sky_map['UNIQ'], true_idx)
        contour_modes = [
            count_modes_moc(sky_map['UNIQ'], i) for i in contour_idxs]
    else:
        searched_modes = np.nan
        contour_modes = np.nan

    # Distance stats now...
    if 'DISTMU' in sky_map.dtype.names:
        dP_dA = sky_map['PROBDENSITY']
        mu = sky_map['DISTMU']
        sigma = sky_map['DISTSIGMA']
        norm = sky_map['DISTNORM']

        # Set up distance grid.
        n_r = 1000
        distmean, _ = distance.parameters_to_marginal_moments(dP, mu, sigma)
        max_r = 6 * distmean
        if true_dist is not None and np.size(true_dist) != 0 \
                and np.max(true_dist) > max_r:
            max_r = np.max(true_dist)
        d_r = max_r / n_r

        # Calculate searched_prob_dist and contour_dists.
        r = d_r * np.arange(1, n_r)
        P_r = distance.marginal_cdf(r, dP, mu, sigma, norm)
        if true_dist is None:
            searched_prob_dist = np.nan
        else:
            searched_prob_dist = np.interp(true_dist, r, P_r, left=0, right=1)
        if len(contours) == 0:
            contour_dists = []
        else:
            # Use np.vstack rather than np.row_stack: the latter was
            # deprecated in NumPy 1.24 and removed in NumPy 2.0, and the two
            # functions are otherwise identical.
            lo, hi = np.interp(
                np.vstack((
                    0.5 * (1 - contours),
                    0.5 * (1 + contours)
                )), P_r, r, left=0, right=np.inf)
            contour_dists = (hi - lo).tolist()

        # Calculate volume of each voxel, defined as the region within the
        # HEALPix pixel and contained within the two centric spherical shells
        # with radii (r - d_r / 2) and (r + d_r / 2).
        dV = (np.square(r) + np.square(d_r) / 12) * d_r * dA.reshape(-1, 1)

        # Calculate probability within each voxel.
        dP = np.exp(
            -0.5 * np.square(
                (r.reshape(1, -1) - mu.reshape(-1, 1)) / sigma.reshape(-1, 1)
            )
        ) * (dP_dA * norm / (sigma * np.sqrt(2 * np.pi))).reshape(-1, 1) * dV
        dP[np.isnan(dP)] = 0  # Suppress invalid values

        # Calculate probability density per unit volume.
        # Note that dP was computed from the Euclidean dV above; only the
        # volume measure is converted to comoving volume here.
        if cosmology:
            dV *= dVC_dVL_for_DL(r)
        dP_dV = dP / dV
        i = np.flipud(np.argsort(dP_dV.ravel()))

        P_flat = np.cumsum(dP.ravel()[i])
        V_flat = np.cumsum(dV.ravel()[i])

        contour_vols = np.interp(
            contours, P_flat, V_flat, left=0, right=np.inf).tolist()
        P = np.empty_like(P_flat)
        V = np.empty_like(V_flat)
        P[i] = P_flat
        V[i] = V_flat
        P = P.reshape(dP.shape)
        V = V.reshape(dV.shape)
        if true_dist is None:
            searched_vol = searched_prob_vol = probdensity_vol = np.nan
        else:
            i_radec = true_idx
            i_dist = np.digitize(true_dist, r) - 1
            probdensity_vol = dP_dV[i_radec, i_dist]
            searched_prob_vol = P[i_radec, i_dist]
            searched_vol = V[i_radec, i_dist]
    else:
        searched_vol = searched_prob_vol = searched_prob_dist \
            = probdensity_vol = np.nan
        contour_dists = [np.nan] * len(contours)
        contour_vols = [np.nan] * len(contours)

    # Done.
    return CrossmatchResult(
        searched_area, searched_prob, offset, searched_modes, contour_areas,
        area_probs, contour_modes, searched_prob_dist, contour_dists,
        searched_vol, searched_prob_vol, contour_vols, probdensity,
        probdensity_vol)

ligo/skymap/postprocess/ellipse.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
#
# Copyright (C) 2013-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import astropy_healpix as ah
from astropy import units as u
from astropy.wcs import WCS
import healpy as hp
import numpy as np

from .. import moc
from ..extern.numpy.quantile import quantile

__all__ = ('find_ellipse',)


def find_ellipse(prob, cl=90, projection='ARC', nest=False):
    """For a HEALPix map, find an ellipse that contains a given probability.

    The orientation is defined as the angle of the semimajor axis
    counterclockwise from west on the plane of the sky. If you think of the
    semimajor distance as the width of the ellipse, then the orientation is the
    clockwise rotation relative to the image x-axis. Equivalently, the
    orientation is the position angle of the semi-minor axis.

    These conventions match the definitions used in DS9 region files [1]_ and
    Aladin drawing commands [2]_.

    Parameters
    ----------
    prob : np.ndarray, astropy.table.Table
        The HEALPix probability map, either as a full rank explicit array
        or as a multi-order map.
    cl : float, np.ndarray
        The desired credible level or levels (default: 90).
    projection : str, optional
        The WCS projection (default: 'ARC', or zenithal equidistant).
        For a list of possible values, see the Astropy documentation [3]_.
    nest : bool
        HEALPix pixel ordering (default: False, or ring ordering).

    Returns
    -------
    ra : float
        The ellipse center right ascension in degrees.
    dec : float
        The ellipse center declination in degrees.
    a : float, np.ndarray
        The length of the semimajor axis in degrees.
    b : float, np.ndarray
        The length of the semiminor axis in degrees.
    pa : float
        The orientation of the ellipse axis on the plane of the sky in degrees.
    area : float, np.ndarray
        The area of the ellipse in square degrees.

    Notes
    -----
    The center of the ellipse is the median a posteriori sky position. The
    length and orientation of the semi-major and semi-minor axes are measured
    as follows:

    1. The sky map is transformed to a WCS projection that may be specified by
       the caller. The default projection is ``ARC`` (zenithal equidistant), in
       which radial distances are proportional to the physical angular
       separation from the center point.
    2. A 1-sigma ellipse is estimated by calculating the covariance matrix in
       the projected image plane using three rounds of sigma clipping to reject
       distant outlier points.
    3. The 1-sigma ellipse is inflated until it encloses an integrated
       probability of ``cl`` (default: 90%).

    The function returns a tuple of the right ascension, declination,
    semi-major distance, semi-minor distance, and orientation angle, all in
    degrees.

    If no ellipse can be found that contains integrated probability greater
    than or equal to the desired credible level ``cl``, then the return values
    ``a``, ``b``, and ``area`` will be set to nan.

    References
    ----------
    .. [1] http://ds9.si.edu/doc/ref/region.html
    .. [2] http://aladin.u-strasbg.fr/java/AladinScriptManual.gml#draw
    .. [3] http://docs.astropy.org/en/stable/wcs/index.html#supported-projections

    Examples
    --------
    **Example 1**

    First, we need some imports.

    >>> from astropy.io import fits
    >>> from astropy.utils.data import download_file
    >>> from astropy.wcs import WCS
    >>> import healpy as hp
    >>> from reproject import reproject_from_healpix
    >>> import subprocess

    Next, we download the BAYESTAR sky map for GW170817 from the
    LIGO Document Control Center.

    >>> url = 'https://dcc.ligo.org/public/0146/G1701985/001/bayestar.fits.gz'  # doctest: +SKIP
    >>> filename = download_file(url, cache=True, show_progress=False)  # doctest: +SKIP
    >>> _, healpix_hdu = fits.open(filename)  # doctest: +SKIP
    >>> prob = hp.read_map(healpix_hdu, verbose=False)  # doctest: +SKIP

    Then, we calculate ellipse and write it to a DS9 region file.

    >>> ra, dec, a, b, pa, area = find_ellipse(prob)  # doctest: +SKIP
    >>> print(*np.around([ra, dec, a, b, pa, area], 5))  # doctest: +SKIP
    195.03732 -19.29358 8.66545 1.1793 63.61698 32.07665
    >>> s = 'fk5;ellipse({},{},{},{},{})'.format(ra, dec, a, b, pa)  # doctest: +SKIP
    >>> open('ds9.reg', 'w').write(s)  # doctest: +SKIP

    Then, we reproject a small patch of the HEALPix map, and save it to a file.

    >>> wcs = WCS()  # doctest: +SKIP
    >>> wcs.wcs.ctype = ['RA---ARC', 'DEC--ARC']  # doctest: +SKIP
    >>> wcs.wcs.crval = [ra, dec]  # doctest: +SKIP
    >>> wcs.wcs.crpix = [128, 128]  # doctest: +SKIP
    >>> wcs.wcs.cdelt = [-0.1, 0.1]  # doctest: +SKIP
    >>> img, _ = reproject_from_healpix(healpix_hdu, wcs, [256, 256])  # doctest: +SKIP
    >>> img_hdu = fits.ImageHDU(img, wcs.to_header())  # doctest: +SKIP
    >>> img_hdu.writeto('skymap.fits')  # doctest: +SKIP

    Now open the image and region file in DS9. You should find that the ellipse
    encloses the probability hot spot. You can load the sky map and region file
    from the command line:

    .. code-block:: sh

        $ ds9 skymap.fits -region ds9.reg

    Or you can do this manually:

        1. Open DS9.
        2. Open the sky map: select "File->Open..." and choose ``skymap.fits``
           from the dialog box.
        3. Open the region file: select "Regions->Load Regions..." and choose
           ``ds9.reg`` from the dialog box.

    Now open the image and region file in Aladin.

        1. Open Aladin.
        2. Open the sky map: select "File->Load Local File..." and choose
           ``skymap.fits`` from the dialog box.
        3. Open the sky map: select "File->Load Local File..." and choose
           ``ds9.reg`` from the dialog box.

    You can also compare the original HEALPix file with the ellipse in Aladin:

        1. Open Aladin.
        2. Open the HEALPix file by pasting the URL from the top of this
           example in the Command field at the top of the window and hitting
           return, or by selecting "File->Load Direct URL...", pasting the URL,
           and clicking "Submit."
        3. Open the sky map: select "File->Load Local File..." and choose
           ``ds9.reg`` from the dialog box.

    **Example 2**

    This example shows that we get approximately the same answer for GW170817
    if we read it in as a multi-order map.

    >>> from ..io import read_sky_map  # doctest: +SKIP
    >>> skymap_moc = read_sky_map(healpix_hdu, moc=True)  # doctest: +SKIP
    >>> ellipse = find_ellipse(skymap_moc)  # doctest: +SKIP
    >>> print(*np.around(ellipse, 5))  # doctest: +SKIP
    195.03709 -19.27589 8.67611 1.18167 63.60454 32.08015

    **Example 3**

    I'm not showing the `ra` or `pa` output from the examples below because
    the right ascension is arbitrary when dec=90° and the position angle is
    arbitrary when a=b; their arbitrary values may vary depending on your math
    library. Also, I add 0.0 to the outputs because on some platforms you tend
    to get values of dec or pa that get rounded to -0.0, which is within
    numerical precision but would break the doctests (see
    https://stackoverflow.com/questions/11010683).

    This is an example sky map that is uniform in sin(theta) out to a given
    radius in degrees. The 90% credible radius should be 0.9 * radius. (There
    will be deviations for small radius due to finite resolution.)

    >>> def make_uniform_in_sin_theta(radius, nside=512):
    ...     npix = ah.nside_to_npix(nside)
    ...     theta, phi = hp.pix2ang(nside, np.arange(npix))
    ...     theta_max = np.deg2rad(radius)
    ...     prob = np.where(theta <= theta_max, 1 / np.sin(theta), 0)
    ...     return prob / prob.sum()
    ...

    >>> prob = make_uniform_in_sin_theta(1)
    >>> ra, dec, a, b, pa, area = find_ellipse(prob)
    >>> dec, a, b, area  # doctest: +FLOAT_CMP
    (89.90862520480792, 0.8703361458208101, 0.8703357768874356, 2.3788811576269793)

    >>> prob = make_uniform_in_sin_theta(10)
    >>> ra, dec, a, b, pa, area = find_ellipse(prob)
    >>> dec, a, b, area  # doctest: +FLOAT_CMP
    (89.90827657529562, 9.024846562072119, 9.024842703023802, 255.11972196535515)

    >>> prob = make_uniform_in_sin_theta(120)
    >>> ra, dec, a, b, pa, area = find_ellipse(prob)
    >>> dec, a, b, area  # doctest: +FLOAT_CMP
    (90.0, 107.9745037610576, 107.97450376105758, 26988.70467497216)

    **Example 4**

    These are approximately Gaussian distributions.

    >>> from scipy import stats
    >>> def make_gaussian(mean, cov, nside=512):
    ...     npix = ah.nside_to_npix(nside)
    ...     xyz = np.transpose(hp.pix2vec(nside, np.arange(npix)))
    ...     dist = stats.multivariate_normal(mean, cov)
    ...     prob = dist.pdf(xyz)
    ...     return prob / prob.sum()
    ...

    This one is centered at RA=45°, Dec=0° and has a standard deviation of ~1°.

    >>> prob = make_gaussian(
    ...     [1/np.sqrt(2), 1/np.sqrt(2), 0],
    ...     np.square(np.deg2rad(1)))
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (45.0, 0.0, 2.1424077148886744, 2.1420790721225518, 90.0, 14.467701995920123)

    This one is centered at RA=45°, Dec=0°, and is elongated in the north-south
    direction.

    >>> prob = make_gaussian(
    ...     [1/np.sqrt(2), 1/np.sqrt(2), 0],
    ...     np.diag(np.square(np.deg2rad([1, 1, 10]))))
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (44.99999999999999, 0.0, 13.58768882719899, 2.0829846178241853, 90.0, 88.57796576937031)

    This one is centered at RA=0°, Dec=0°, and is elongated in the east-west
    direction.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     np.diag(np.square(np.deg2rad([1, 10, 1]))))
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 13.583918022027149, 2.0823769912401433, 0.0, 88.54622940628761)

    This one is centered at RA=0°, Dec=0°, and has its long axis tilted about
    10° to the west of north.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     [[0.1, 0, 0],
    ...      [0, 0.1, -0.15],
    ...      [0, -0.15, 1]])
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 64.7713312709293, 33.50754131182681, 80.78231196786838, 6372.344658663038)

    This one is centered at RA=0°, Dec=0°, and has its long axis tilted about
    10° to the east of north.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     [[0.1, 0, 0],
    ...      [0, 0.1, 0.15],
    ...      [0, 0.15, 1]])
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 64.77133127093047, 33.50754131182745, 99.21768803213159, 6372.344658663096)

    This one is centered at RA=0°, Dec=0°, and has its long axis tilted about
    80° to the east of north.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     [[0.1, 0, 0],
    ...      [0, 1, 0.15],
    ...      [0, 0.15, 0.1]])
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 64.7756448603915, 33.509863018519894, 170.78252287327365, 6372.425731592412)

    This one is centered at RA=0°, Dec=0°, and has its long axis tilted about
    80° to the west of north.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     [[0.1, 0, 0],
    ...      [0, 1, -0.15],
    ...      [0, -0.15, 0.1]])
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 64.77564486039148, 33.50986301851987, 9.217477126726322, 6372.42573159241)

    **Example 5**

    You can ask for other credible levels:
    >>> find_ellipse(prob, cl=50)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 37.054207653285076, 19.168955020015982, 9.217477126726322, 2182.5580135410632)

    Or even for multiple credible levels:
    >>> find_ellipse(prob, cl=[50, 90])  # doctest: +FLOAT_CMP
    (0.0, 0.0, array([37.05420765, 64.77564486]), array([19.16895502, 33.50986302]), 9.217477126726322, array([2182.55801354, 6372.42573159]))
    """  # noqa: E501
    try:
        prob['UNIQ']
    except (IndexError, KeyError, ValueError):
        # Flat (fixed-resolution) HEALPix array: every pixel has the same
        # area, so `area` is a scalar in square degrees.
        npix = len(prob)
        nside = ah.npix_to_nside(npix)
        ipix = range(npix)
        area = ah.nside_to_pixel_area(nside).to_value(u.deg**2)
    else:
        # Multi-order (UNIQ-indexed) map: decode per-tile resolution and
        # convert probability density to probability per tile.
        order, ipix = moc.uniq2nest(prob['UNIQ'])
        nside = 1 << order.astype(int)
        ipix = ipix.astype(int)
        area = ah.nside_to_pixel_area(nside).to_value(u.sr)
        prob = prob['PROBDENSITY'] * area
        # Convert per-tile areas from steradians to square degrees.
        area *= np.square(180 / np.pi)
        nest = True

    # Find median a posteriori sky position.
    xyz0 = [quantile(x, 0.5, weights=prob)
            for x in hp.pix2vec(nside, ipix, nest=nest)]
    (ra,), (dec,) = hp.vec2ang(np.asarray(xyz0), lonlat=True)

    # Construct WCS with the specified projection
    # and centered on mean direction.
    w = WCS()
    w.wcs.crval = [ra, dec]
    w.wcs.ctype = ['RA---' + projection, 'DEC--' + projection]

    # Transform HEALPix to the specified projection.
    xy = w.wcs_world2pix(
        np.transpose(
            hp.pix2ang(
                nside, ipix, nest=nest, lonlat=True)), 1)

    # Keep only values that were inside the projection.
    keep = np.logical_and.reduce(np.isfinite(xy), axis=1)
    xy = xy[keep]
    prob = prob[keep]
    if not np.isscalar(area):
        area = area[keep]

    # Find covariance matrix, performing three rounds of sigma-clipping
    # to reject outliers.
    keep = np.ones(len(xy), dtype=bool)
    for _ in range(3):
        c = np.cov(xy[keep], aweights=prob[keep], rowvar=False)
        # Mahalanobis distance of every point from the weighted centroid,
        # in units of sigma.
        nsigmas = np.sqrt(np.sum(xy.T * np.linalg.solve(c, xy.T), axis=0))
        keep &= (nsigmas < 3)

    # Find the number of sigma that enclose the cl% credible level.
    i = np.argsort(nsigmas)
    nsigmas = nsigmas[i]
    cls = np.cumsum(prob[i])
    if np.isscalar(area):
        careas = np.arange(1, len(i) + 1) * area
    else:
        careas = np.cumsum(area[i])
    # np.multiply rather than * to automatically convert to ndarray if needed
    cl = np.multiply(cl, 1e-2)
    # right=np.nan makes a, b, and area come out nan when no ellipse
    # encloses the requested credible level (see docstring).
    nsigma = np.interp(cl, cls, nsigmas, right=np.nan)
    area = np.interp(cl, cls, careas, right=np.nan)

    # Find the eigendecomposition of the covariance matrix.
    # Note: `w` is reused here and now holds eigenvalues, not the WCS.
    w, v = np.linalg.eigh(c)

    # Find the semi-minor and semi-major axes.
    # eigh returns eigenvalues in ascending order, so sqrt(w) maps to (b, a).
    b, a = (nsigma * root_w for root_w in np.sqrt(w))

    # Find the position angle.
    pa = np.rad2deg(np.arctan2(*v[0]))

    # An ellipse is symmetric under rotations of 180°.
    # Return the smallest possible positive position angle.
    pa %= 180

    # Done!
    return ra, dec, a, b, pa, area

ligo/skymap/postprocess/util.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Postprocessing utilities for HEALPix sky maps."""

import astropy_healpix as ah
from astropy.coordinates import (CartesianRepresentation, SkyCoord,
                                 UnitSphericalRepresentation)
from astropy import units as u
import healpy as hp
import numpy as np

__all__ = ('find_greedy_credible_levels', 'smooth_ud_grade', 'posterior_mean',
           'posterior_max')


def find_greedy_credible_levels(p, ranking=None):
    """Find the greedy credible levels of a (possibly multi-dimensional) array.

    Each output entry is the running total of `p` accumulated in descending
    `ranking` order, up to and including that entry.

    Parameters
    ----------
    p : np.ndarray
        The input array, typically a HEALPix image.

    ranking : np.ndarray, optional
        The array to rank in order to determine the greedy order.
        The default is `p` itself.

    Returns
    -------
    cls : np.ndarray
        An array with the same shape as `p`, with values ranging from `0`
        to `p.sum()`, representing the greedy credible level to which each
        entry in the array belongs.

    """
    p = np.asarray(p)
    flat = p.ravel()
    rank = flat if ranking is None else np.ravel(ranking)
    # Indices of the flattened array, highest-ranked first.
    order = np.argsort(rank)[::-1]
    # Scatter the cumulative sums back to the original positions.
    levels = np.empty_like(flat)
    levels[order] = np.cumsum(flat[order])
    return levels.reshape(p.shape)


def smooth_ud_grade(m, nside, nest=False):
    """Resample a sky map to a new resolution using bilinear interpolation.

    Parameters
    ----------
    m : np.ndarray
        The input HEALPix array.

    nside : int
        The HEALPix resolution parameter of the output sky map.

    nest : bool, default=False
        Indicates whether the input sky map is in nested rather than
        ring-indexed HEALPix coordinates (default: ring).

    Returns
    -------
    new_m : np.ndarray
        The resampled HEALPix array. The sum of `m` is approximately preserved.

    """
    npix = ah.nside_to_npix(nside)
    theta, phi = hp.pix2ang(nside, np.arange(npix), nest=nest)
    new_m = hp.get_interp_val(m, theta, phi, nest=nest)
    # Rescale by the ratio of pixel counts so the total is approximately
    # conserved under the resolution change.
    return new_m * len(m) / len(new_m)


def posterior_mean(prob, nest=False):
    """Find the mean a posteriori sky position of a HEALPix probability map.

    The probability-weighted average of the Cartesian unit vectors of all
    pixel centers is computed and then reinterpreted as a direction on the
    unit sphere.

    Parameters
    ----------
    prob : np.ndarray
        The HEALPix probability map (full rank, one value per pixel).
    nest : bool, default=False
        HEALPix pixel ordering (default: False, or ring ordering).

    Returns
    -------
    pos : astropy.coordinates.SkyCoord
        The mean sky position as a unit-spherical coordinate.

    """
    npix = len(prob)
    nside = ah.npix_to_nside(npix)
    xyz = hp.pix2vec(nside, np.arange(npix), nest=nest)
    mean_xyz = np.average(xyz, axis=1, weights=prob)
    pos = SkyCoord(*mean_xyz, representation_type=CartesianRepresentation)
    # Dropping the radial component projects the mean back onto the sphere.
    pos.representation_type = UnitSphericalRepresentation
    return pos


def posterior_max(prob, nest=False):
    """Find the maximum a posteriori sky position of a HEALPix map.

    Parameters
    ----------
    prob : np.ndarray
        The HEALPix probability map (full rank, one value per pixel).
    nest : bool, default=False
        HEALPix pixel ordering (default: False, or ring ordering).

    Returns
    -------
    pos : astropy.coordinates.SkyCoord
        The center of the most probable pixel.

    """
    npix = len(prob)
    nside = ah.npix_to_nside(npix)
    i = np.argmax(prob)
    return SkyCoord(
        *hp.pix2ang(nside, i, nest=nest, lonlat=True), unit=u.deg)

ligo/skymap/tool/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
#
# Copyright (C) 2013-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Functions that support the command line interface."""

import argparse
import glob
import inspect
import itertools
import logging
import os
import sys

import numpy as np

from ..util import sqlite
from .. import version

# Human-readable "<package> <version>" string shown by the --version option
# that ArgumentParser (below) adds to every tool.
version_string = version.__package__ + ' ' + version.version


class FileType(argparse.FileType):
    """Inherit from :class:`argparse.FileType` to enable opening stdin or
    stdout in binary mode.

    This is a workaround for https://bugs.python.org/issue14156.
    """

    def __call__(self, string):
        # Only the special filename '-' in a binary mode needs intercepting;
        # everything else is handled by the stock argparse implementation.
        mode = self._mode
        if string != '-' or 'b' not in mode:
            return super().__call__(string)
        if 'r' in mode:
            return sys.stdin.buffer
        if 'w' in mode:
            return sys.stdout.buffer
        return super().__call__(string)


class EnableAction(argparse.Action):
    """Argparse action providing a ``--enable-X`` / ``--disable-X`` pair.

    The option must be registered under its ``--enable-`` spelling; the
    matching ``--disable-`` spelling is added automatically. The destination
    defaults to True.
    """

    def __init__(self,
                 option_strings,
                 dest,
                 default=True,
                 required=False,
                 help=None):
        enable_opt, = option_strings
        if not enable_opt.startswith('--enable-'):
            raise ValueError('Option string must start with --enable-')
        disable_opt = enable_opt.replace('--enable-', '--disable-')
        super().__init__(
            [enable_opt, disable_opt],
            dest=dest,
            nargs=0,
            default=default,
            required=required,
            help=help)

    def __call__(self, parser, namespace, values, option_string):
        # Which spelling the user typed determines the stored boolean.
        if option_string.startswith('--disable-'):
            setattr(namespace, self.dest, False)
        elif option_string.startswith('--enable-'):
            setattr(namespace, self.dest, True)
        else:
            raise RuntimeError('This code cannot be reached')


class GlobAction(argparse._StoreAction):
    """Generate a list of filenames from a list of filenames and globs."""

    def __call__(self, parser, namespace, values, *args, **kwargs):
        # Expand each value as a shell glob and flatten into one list.
        values = list(
            itertools.chain.from_iterable(glob.iglob(s) for s in values))
        # Store only if something matched; otherwise the previously stored
        # value (or None) is what gets counted below.
        if values:
            super().__call__(parser, namespace, values, *args, **kwargs)
        nvalues = getattr(namespace, self.dest)
        nvalues = 0 if nvalues is None else len(nvalues)
        # Validate the number of matched files against the declared nargs;
        # msg stays None when the count is acceptable.
        if self.nargs == argparse.OPTIONAL:
            if nvalues > 1:
                msg = 'expected at most one file'
            else:
                msg = None
        elif self.nargs == argparse.ONE_OR_MORE:
            if nvalues < 1:
                msg = 'expected at least one file'
            else:
                msg = None
        elif self.nargs == argparse.ZERO_OR_MORE:
            msg = None
        elif int(self.nargs) != nvalues:
            msg = 'expected exactly %s file' % self.nargs
            if self.nargs != 1:
                msg += 's'
        else:
            msg = None
        if msg is not None:
            msg += ', but found '
            msg += '{} file'.format(nvalues)
            if nvalues != 1:
                msg += 's'
            raise argparse.ArgumentError(self, msg)


def get_waveform_parser():
    """Build a parent argument parser with template waveform options."""
    waveform_parser = argparse.ArgumentParser(add_help=False)
    opts = waveform_parser.add_argument_group(
        'waveform options', 'Options that affect template waveform generation')
    # FIXME: The O1 uberbank high-mass template, SEOBNRv2_ROM_DoubleSpin, does
    # not support frequencies less than 30 Hz.
    opts.add_argument('--f-low', type=float, metavar='Hz', default=30,
                      help='Low frequency cutoff')
    opts.add_argument('--f-high-truncate', type=float, default=0.95,
                      help='Truncate waveform at this fraction of the '
                           'maximum frequency of the PSD')
    opts.add_argument('--waveform', default='o2-uberbank',
                      help='Template waveform approximant: e.g., '
                           'TaylorF2threePointFivePN')
    return waveform_parser


def get_posterior_parser():
    """Build a parent argument parser with BAYESTAR posterior options."""
    posterior_parser = argparse.ArgumentParser(add_help=False)
    opts = posterior_parser.add_argument_group(
        'posterior options', 'Options that affect the BAYESTAR posterior')
    opts.add_argument('--min-distance', type=float, metavar='Mpc',
                      help='Minimum distance of prior in megaparsecs')
    opts.add_argument('--max-distance', type=float, metavar='Mpc',
                      help='Maximum distance of prior in megaparsecs')
    opts.add_argument('--prior-distance-power', type=int, metavar='-1|2',
                      default=2,
                      help='Distance prior: -1 for uniform in log, '
                           '2 for uniform in volume')
    opts.add_argument('--cosmology', action='store_true',
                      help='Use cosmological comoving volume prior')
    opts.add_argument('--enable-snr-series', action=EnableAction,
                      help='Enable input of SNR time series')
    opts.add_argument('--rescale-loglikelihood', type=float, default=0.83,
                      help='Rescale log likelihood by the square of this '
                           'factor to account for excess technical noise '
                           'from search pipeline')
    return posterior_parser


def get_mcmc_parser():
    """Build a parent argument parser with BAYESTAR MCMC sampling options."""
    mcmc_parser = argparse.ArgumentParser(add_help=False)
    opts = mcmc_parser.add_argument_group(
        'BAYESTAR MCMC options', 'BAYESTAR options for MCMC sampling')
    opts.add_argument('--mcmc', action='store_true',
                      help='Use MCMC sampling instead of Gaussian quadrature')
    opts.add_argument('--chain-dump', action='store_true',
                      help='For MCMC methods, dump the sample chain to disk')
    return mcmc_parser


class HelpChoicesAction(argparse.Action):
    """Argparse action for a ``--help-X`` option that lists the supported
    values of the corresponding ``--X`` option and then exits."""

    def __init__(self,
                 option_strings,
                 choices=(),
                 dest=argparse.SUPPRESS,
                 default=argparse.SUPPRESS):
        # Derive the target option name from the --help-X spelling.
        target = option_strings[0].replace('--help-', '')
        super().__init__(
            option_strings=option_strings,
            dest=dest,
            default=default,
            nargs=0,
            help=f'show supported values for --{target} and exit')
        self._name = target
        self._choices = choices

    def __call__(self, parser, namespace, values, option_string=None):
        print(f'Supported values for --{self._name}:')
        for item in self._choices:
            print(item)
        parser.exit()


def type_with_sideeffect(type):
    """Decorator factory: wrap an argparse type-conversion callable so that a
    side effect runs on each converted value.

    ``type`` converts the raw string; the decorated function receives the
    converted value for its side effect; the converted value is returned.
    """
    def decorator(sideeffect):
        # Keep the inner name 'func' generic; argparse only calls it.
        def func(value):
            converted = type(value)
            sideeffect(converted)
            return converted
        return func
    return decorator


@type_with_sideeffect(str)
def loglevel_type(value):
    # argparse `type` callable: returns the string unchanged, but as a side
    # effect configures logging at the requested level. Numeric strings are
    # passed to logging as ints; level names are upper-cased.
    try:
        value = int(value)
    except ValueError:
        value = value.upper()
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=value)


class LogLevelAction(argparse._StoreAction):
    # Store action for the -l/--loglevel option. Overrides any caller-supplied
    # `type` and `metavar`: the type is always loglevel_type (which configures
    # logging as a side effect) and the metavar lists the known level names.

    def __init__(
            self, option_strings, dest, nargs=None, const=None, default=None,
            type=None, choices=None, required=False, help=None, metavar=None):
        # NOTE(review): logging._levelToName is a private API; it maps level
        # numbers to names (CRITICAL, ERROR, ...) and is used only for display.
        metavar = '|'.join(logging._levelToName.values())
        type = loglevel_type
        super().__init__(
            option_strings, dest, nargs=nargs, const=const, default=default,
            type=type, choices=choices, required=required, help=help,
            metavar=metavar)


@type_with_sideeffect(int)
def seed(value):
    # argparse `type` callable: converts to int and, as a side effect, seeds
    # the legacy NumPy global pseudo-random number generator.
    np.random.seed(value)


def get_random_parser():
    """Build a parent argument parser with random number generator options.

    The ``--seed`` option seeds the NumPy pseudo-random number generator as
    a side effect of argument type conversion (see :func:`seed`).
    """
    parser = argparse.ArgumentParser(add_help=False)
    group = parser.add_argument_group(
        'random number generator options',
        # Fixed typo in user-visible help text: 'genrator' -> 'generator'.
        'Options that affect the Numpy pseudo-random number generator')
    group.add_argument(
        '--seed', type=seed, help='Pseudo-random number generator seed '
        '[default: initialized from /dev/urandom or clock]')
    return parser


class HelpFormatter(argparse.RawDescriptionHelpFormatter,
                    argparse.ArgumentDefaultsHelpFormatter):
    """Help formatter that both preserves whitespace in the description and
    shows default values for arguments."""

    pass


class ArgumentParser(argparse.ArgumentParser):
    """An ArgumentParser subclass with some sensible defaults.

    - Any ``.py`` suffix is stripped from the program name, because the
      program is probably being invoked from the stub shell script.

    - The description is taken from the docstring of the file in which the
      ArgumentParser is created.

    - If the description is taken from the docstring, then whitespace in
      the description is preserved.

    - A ``--version`` option is added that prints the version of ligo.skymap.
    """

    def __init__(self,
                 prog=None,
                 usage=None,
                 description=None,
                 epilog=None,
                 parents=[],
                 prefix_chars='-',
                 fromfile_prefix_chars=None,
                 argument_default=None,
                 conflict_handler='error',
                 add_help=True):
        # Inspect the caller's stack frame to derive defaults for the
        # program name and description from the calling module.
        parent_frame = inspect.currentframe().f_back
        if prog is None:
            prog = parent_frame.f_code.co_filename
            prog = os.path.basename(prog)
            # NOTE(review): str.replace removes '.py' anywhere in the name,
            # not only as a suffix; harmless for typical script filenames.
            prog = prog.replace('_', '-').replace('.py', '')
        if description is None:
            # Use the calling module's docstring as the description.
            description = parent_frame.f_globals.get('__doc__', None)
        super().__init__(
            prog=prog,
            usage=usage,
            description=description,
            epilog=epilog,
            parents=parents,
            formatter_class=HelpFormatter,
            prefix_chars=prefix_chars,
            fromfile_prefix_chars=fromfile_prefix_chars,
            argument_default=argument_default,
            conflict_handler=conflict_handler,
            add_help=add_help)
        self.register('action', 'glob', GlobAction)
        self.register('action', 'loglevel', LogLevelAction)
        self.add_argument(
            '--version', action='version', version=version_string)
        self.add_argument(
            '-l', '--loglevel', action='loglevel', default='INFO')


class DirType:
    """Factory for argparse directory arguments.

    Parameters
    ----------
    create : bool
        If True, create the directory (and any missing parents) if needed.
        If False, require that the directory already exist and be listable.

    Calling the instance with a path string returns the string, or raises
    :class:`argparse.ArgumentTypeError` if the directory cannot be created
    or read.
    """

    def __init__(self, create=False):
        self._create = create

    def __call__(self, string):
        if self._create:
            try:
                os.makedirs(string, exist_ok=True)
            except OSError as e:
                # BUG FIX: OSError has no `.message` attribute in Python 3,
                # so the old `e.message` raised AttributeError instead of the
                # intended argparse error. Pass the exception itself, matching
                # the non-create branch below.
                raise argparse.ArgumentTypeError(e)
        else:
            try:
                os.listdir(string)
            except OSError as e:
                raise argparse.ArgumentTypeError(e)
        return string


class SQLiteType(FileType):
    """Open an SQLite database, or fail if it does not exist.

    Here is an example of trying to open a file that does not exist for
    reading (mode='r'). It should raise an exception:

    >>> import tempfile
    >>> filetype = SQLiteType('r')
    >>> filename = tempfile.mktemp()
    >>> # Note, simply check or a FileNotFound error in Python 3.
    >>> filetype(filename)
    Traceback (most recent call last):
      ...
    argparse.ArgumentTypeError: ...

    If the file already exists, then it's fine:

    >>> import sqlite3
    >>> filetype = SQLiteType('r')
    >>> with tempfile.NamedTemporaryFile() as f:
    ...     with sqlite3.connect(f.name) as db:
    ...         _ = db.execute('create table foo (bar char)')
    ...     filetype(f.name)
    <sqlite3.Connection object at ...>

    Here is an example of opening a file for writing (mode='w'), which should
    overwrite the file if it exists. Even if the file was not an SQLite
    database beforehand, this should work:

    >>> filetype = SQLiteType('w')
    >>> with tempfile.NamedTemporaryFile(mode='w') as f:
    ...     print('This is definitely not an SQLite file.', file=f)
    ...     f.flush()
    ...     with filetype(f.name) as db:
    ...         db.execute('create table foo (bar char)')
    <sqlite3.Cursor object at ...>

    Here is an example of opening a file for appending (mode='a'), which should
    NOT overwrite the file if it exists. If the file was not an SQLite database
    beforehand, this should raise an exception.

    >>> filetype = SQLiteType('a')
    >>> with tempfile.NamedTemporaryFile(mode='w') as f:
    ...     print('This is definitely not an SQLite file.', file=f)
    ...     f.flush()
    ...     with filetype(f.name) as db:
    ...         db.execute('create table foo (bar char)')
    Traceback (most recent call last):
      ...
    sqlite3.DatabaseError: ...

    And if the database did exist beforehand, then opening for appending
    (mode='a') should not clobber existing tables.

    >>> filetype = SQLiteType('a')
    >>> with tempfile.NamedTemporaryFile() as f:
    ...     with sqlite3.connect(f.name) as db:
    ...         _ = db.execute('create table foo (bar char)')
    ...     with filetype(f.name) as db:
    ...         db.execute('select count(*) from foo').fetchone()
    (0,)
    """

    def __init__(self, mode):
        # BUG FIX: ``mode not in 'arw'`` was a substring test, so
        # multi-character strings like 'rw' or 'ar' (and the empty string)
        # were silently accepted. Test membership in a tuple instead so
        # that only the exact single-character modes pass.
        if mode not in ('a', 'r', 'w'):
            raise ValueError('Unknown file mode: {}'.format(mode))
        self.mode = mode

    def __call__(self, string):
        try:
            return sqlite.open(string, self.mode)
        except OSError as e:
            raise argparse.ArgumentTypeError(e)


def _sanitize_arg_value_for_xmldoc(value):
    if hasattr(value, 'read'):
        return value.name
    elif isinstance(value, tuple):
        return tuple(_sanitize_arg_value_for_xmldoc(v) for v in value)
    elif isinstance(value, list):
        return [_sanitize_arg_value_for_xmldoc(v) for v in value]
    else:
        return value


def register_to_xmldoc(xmldoc, parser, opts, **kwargs):
    """Record this command-line invocation in the document's process table,
    sanitizing the parsed argument values so that they are storable."""
    from ligo.lw.utils import process
    sanitized = {name: _sanitize_arg_value_for_xmldoc(val)
                 for name, val in vars(opts).items()}
    return process.register_to_xmldoc(
        xmldoc, parser.prog, sanitized, **kwargs, version=version_string)


start_msg = '\
Waiting for input on stdin. Type control-D followed by a newline to terminate.'
stop_msg = 'Reached end of file. Exiting.'


def iterlines(file, start_message=start_msg, stop_message=stop_msg):
    """Iterate over non-empty lines in a file.

    Parameters
    ----------
    file : file object
        An open text file.
    start_message, stop_message : str
        Status messages printed to stderr when iteration begins and ends,
        but only if *file* is attached to a terminal.

    Yields
    ------
    str
        Each line stripped of surrounding whitespace, skipping lines that
        are empty after stripping.

    """
    is_tty = os.isatty(file.fileno())

    if is_tty:
        print(start_message, file=sys.stderr)

    while True:
        # Read a line.
        line = file.readline()

        if not line:
            # If we reached EOF, then exit.
            break

        # Strip off the trailing newline and any whitespace.
        line = line.strip()

        # Emit the line if it is not empty.
        if line:
            yield line

    if is_tty:
        # BUG FIX: this previously printed the module-level ``stop_msg``
        # unconditionally, ignoring the caller-supplied ``stop_message``
        # parameter.
        print(stop_message, file=sys.stderr)


# Map output filename extensions to the ``compress`` argument accepted by
# ligo.lw.utils.write_fileobj.
_compress_arg_map = {
    '.bz2': 'bz2',
    '.gz': 'gz',
    '.xz': 'xz',
    '.zst': 'zst'
}


def write_fileobj(xmldoc, f):
    """Write a LIGO-LW document to an open file object, selecting the
    compression scheme from the file's extension (case-insensitively)."""
    import ligo.lw.utils

    extension = os.path.splitext(f.name)[1]
    compress = _compress_arg_map.get(extension.lower())

    # SignalsTrap wraps the write; presumably it defers interrupting
    # signals so that a partially written file is not left behind — see
    # ligo.lw.utils for its exact contract.
    with ligo.lw.utils.SignalsTrap():
        ligo.lw.utils.write_fileobj(xmldoc, f, compress=compress)

ligo/skymap/tool/bayestar_inject.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
548  
549  
550  
551  
552  
553  
554  
555  
#
# Copyright (C) 2019-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Rough-cut injection tool.

The idea is to efficiently sample events, uniformly in "sensitive volume"
(differential comoving volume divided by 1 + z), and from a distribution of
masses and spins, such that later detection cuts will not reject an excessive
number of events.

This occurs in two steps. First, we divide the intrinsic parameter space into a
very coarse 10x10x10x10 grid and calculate the maximum horizon distance in each
grid cell. Second, we directly sample injections jointly from the mass and spin
distribution and a uniform and isotropic spatial distribution with a redshift
cutoff that is piecewise constant in the masses and spins.
"""

from functools import partial

from astropy import cosmology
from astropy.cosmology.utils import vectorize_redshift_method
from astropy import units
from astropy.units import dimensionless_unscaled
import numpy as np
from scipy.integrate import quad, fixed_quad
from scipy.interpolate import interp1d
from scipy.optimize import root_scalar
from scipy.ndimage import maximum_filter

from ..util import progress_map
from ..bayestar.filter import sngl_inspiral_psd
from . import (
    ArgumentParser, FileType, get_random_parser, register_to_xmldoc,
    write_fileobj)

try:
    from astropy.cosmology import available as available_cosmologies
except ImportError:
    # FIXME: Remove once we drop support for Astropy < 5.1
    from astropy.cosmology.parameters import available as available_cosmologies


def get_decisive_snr(snrs, min_triggers):
    """Return the SNR for the trigger that decides if an event is detectable.

    Parameters
    ----------
    snrs : list
        List of SNRs (floats).
    min_triggers : int
        Minimum number of triggers to form a coincidence.

    Returns
    -------
    decisive_snr : float
        The ``min_triggers``-th largest value in *snrs*.

    """
    ranked = sorted(snrs, reverse=True)
    return ranked[min_triggers - 1]


def lo_hi_nonzero(x):
    """Return the indices of the first and last nonzero entries of *x*,
    flattened to one dimension."""
    indices = np.ravel(x).nonzero()[0]
    return indices[0], indices[-1]


class GWCosmo:
    """Evaluate GW distance figures of merit for a given cosmology.

    Parameters
    ----------
    cosmology : :class:`astropy.cosmology.FLRW`
        The cosmological model.

    """

    def __init__(self, cosmology):
        # NOTE(review): the parameter name shadows the module-level
        # ``astropy.cosmology`` import within this method's scope.
        self.cosmo = cosmology

    def z_at_snr(self, psds, waveform, f_low, snr_threshold, min_triggers,
                 mass1, mass2, spin1z, spin2z):
        """
        Get redshift at which a waveform attains a given SNR.

        Parameters
        ----------
        psds : list
            List of :class:`lal.REAL8FrequencySeries` objects.
        waveform : str
            Waveform approximant name.
        f_low : float
            Low-frequency cutoff for template.
        snr_threshold : float
            Minimum single-detector SNR.
        min_triggers : int
            Minimum number of triggers to form a coincidence.
        mass1, mass2, spin1z, spin2z : float
            Component masses and aligned spin components of the binary.

        Returns
        -------
        float
            The redshift at which the decisive SNR equals
            ``snr_threshold``.

        """
        # Construct waveform
        series = sngl_inspiral_psd(waveform, f_low=f_low,
                                   mass1=mass1, mass2=mass2,
                                   spin1z=spin1z, spin2z=spin2z)
        # Interpolate the log of the signal power spectrum as a function
        # of log frequency over the band where it is nonzero. Out-of-band
        # queries return -inf, i.e. zero power after exponentiation.
        i_lo, i_hi = lo_hi_nonzero(series.data.data)
        log_f = np.log(series.f0 + series.deltaF * np.arange(i_lo, i_hi + 1))
        log_f_lo = log_f[0]
        log_f_hi = log_f[-1]
        num = interp1d(
            log_f, np.log(series.data.data[i_lo:i_hi + 1]),
            fill_value=-np.inf, bounds_error=False, assume_sorted=True)

        # For each detector, interpolate log(f / PSD(f)) as a function of
        # log frequency; the extra factor of f accounts for the change of
        # variables from df to d(log f) in the SNR integral below.
        denoms = []
        for series in psds:
            i_lo, i_hi = lo_hi_nonzero(
                np.isfinite(series.data.data) & (series.data.data != 0))
            log_f = np.log(
                series.f0 + series.deltaF * np.arange(i_lo, i_hi + 1))
            denom = interp1d(
                log_f, log_f - np.log(series.data.data[i_lo:i_hi + 1]),
                fill_value=-np.inf, bounds_error=False, assume_sorted=True)
            denoms.append(denom)

        def snr_at_z(z):
            # Redshifting scales the observed frequency by 1 / (1 + z),
            # i.e. shifts the source spectrum by -log(1 + z) in log
            # frequency; evaluate the numerator at the shifted frequency
            # and truncate the upper integration limit accordingly.
            logzp1 = np.log(z + 1)
            integrand = lambda log_f: [
                np.exp(num(log_f + logzp1) + denom(log_f)) for denom in denoms]
            integrals, _ = fixed_quad(
                integrand, log_f_lo, log_f_hi - logzp1, n=1024)
            snr = get_decisive_snr(np.sqrt(4 * integrals), min_triggers)
            with np.errstate(divide='ignore'):
                # Suppress the divide-by-zero warning at z = 0, where the
                # angular diameter distance vanishes.
                snr /= self.cosmo.angular_diameter_distance(z).to_value(
                    units.Mpc)
            return snr

        def root_func(z):
            return snr_at_z(z) - snr_threshold

        # Solve for the redshift at which the decisive SNR crosses the
        # threshold.
        return root_scalar(root_func, bracket=(0, 1e3)).root

    def get_max_z(self, psds, waveform, f_low, snr_threshold, min_triggers,
                  mass1, mass2, spin1z, spin2z, jobs=1):
        """Evaluate :meth:`z_at_snr` over broadcast arrays of waveform
        parameters, optionally distributing the work over *jobs* workers."""
        # Calculate the maximum distance on the grid.
        params = [mass1, mass2, spin1z, spin2z]
        shape = np.broadcast_shapes(*(param.shape for param in params))
        result = list(progress_map(
            partial(self.z_at_snr, psds, waveform, f_low,
                    snr_threshold, min_triggers),
            *(param.ravel() for param in params),
            jobs=jobs))
        result = np.reshape(result, shape)

        # Sanity-check the root finding performed in z_at_snr.
        assert np.all(result >= 0), 'some redshifts are negative'
        assert np.all(np.isfinite(result)), 'some redshifts are not finite'
        return result

    @vectorize_redshift_method
    def _sensitive_volume_integral(self, z):
        # Dimensionless integral shared by sensitive_volume and
        # sensitive_distance: \int_0^z dV_C/dz / (4 pi sr (1 + z)) dz in
        # units of the cubed Hubble distance.
        dh3_sr = self.cosmo.hubble_distance**3 / units.sr

        def integrand(z):
            result = self.cosmo.differential_comoving_volume(z)
            result /= (1 + z) * dh3_sr
            return result.to_value(dimensionless_unscaled)

        result, _ = quad(integrand, 0, z)
        return result

    def sensitive_volume(self, z):
        """Sensitive volume :math:`V(z)` out to redshift :math:`z`.

        Given a population of events that occur at a constant rate density
        :math:`R` per unit comoving volume per unit proper time, the number of
        observed events out to a redshift :math:`N(z)` over an observation time
        :math:`T` is :math:`N(z) = R T V(z)`.
        """
        dh3 = self.cosmo.hubble_distance**3
        return 4 * np.pi * dh3 * self._sensitive_volume_integral(z)

    def sensitive_distance(self, z):
        r"""Sensitive distance as a function of redshift :math:`z`.

        The sensitive distance is the distance :math:`d_s(z)` defined such that
        :math:`V(z) = 4/3\pi {d_s(z)}^3`, where :math:`V(z)` is the sensitive
        volume.
        """
        # With V(z) = 4 pi d_H^3 * integral, solving 4/3 pi d_s^3 = V(z)
        # gives d_s = d_H * cbrt(3 * integral).
        dh = self.cosmo.hubble_distance
        return dh * np.cbrt(3 * self._sensitive_volume_integral(z))


def cell_max(values):
    r"""
    Find pairwise max of consecutive elements across all axes of an array.

    Parameters
    ----------
    values : :class:`numpy.ndarray`
        An input array of :math:`n` dimensions,
        :math:`(m_0, m_1, \dots, m_{n-1})`.

    Returns
    -------
    maxima : :class:`numpy.ndarray`
        An output array of :math:`n` dimensions, each axis one element
        shorter than the corresponding input axis,
        :math:`(m_0 - 1, m_1 - 1, \dots, m_{n-1} - 1)`.

    """
    # A size-2 maximum filter takes, at each position, the max over the
    # cell extending one step back along every axis; dropping the first
    # entry of each axis discards the zero-padded border.
    filtered = maximum_filter(values, size=2, mode='constant')
    trim = tuple(slice(1, None) for _ in range(np.ndim(values)))
    return filtered[trim]


def assert_not_reached():  # pragma: no cover
    """Fail loudly from a code path that should be logically unreachable."""
    message = 'This line should not be reached.'
    raise AssertionError(message)


def parser():
    """Construct the command-line argument parser for bayestar-inject."""
    parser = ArgumentParser(parents=[get_random_parser()])
    parser.add_argument(
        '--cosmology', choices=available_cosmologies,
        default='Planck15', help='Cosmological model')

    # The intrinsic mass/spin distribution comes from exactly one of two
    # sources: a named preset, or a table of samples supplied by the user.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        '--distribution', help='Use a preset distribution', choices=(
            'bns_astro', 'bns_broad', 'nsbh_astro', 'nsbh_broad',
            'bbh_astro', 'bbh_broad'))
    group.add_argument(
        '--distribution-samples',
        help='Load samples of the intrinsic mass and spin distribution from '
             'any file that can be read as an Astropy table. The table '
             'columns should be mass1, mass2, spin1z, and spin2z.')

    # Detection-threshold options that control the redshift cutoff.
    parser.add_argument(
        '--reference-psd', type=FileType('rb'), metavar='PSD.xml[.gz]',
        required=True, help='PSD file')
    parser.add_argument(
        '--f-low', type=float, default=25.0,
        help='Low frequency cutoff in Hz')
    parser.add_argument(
        '--snr-threshold', type=float, default=4.,
        help='Single-detector SNR threshold')
    parser.add_argument(
        '--min-triggers', type=int, default=2,
        help='Emit coincidences only when at least this many triggers '
        'are found')
    parser.add_argument(
        '--min-snr', type=float,
        help='Minimum decisive SNR of injections given the reference PSDs. '
             'Deprecated; use the synonymous --snr-threshold option instead.')
    parser.add_argument(
        '--max-distance', type=float, metavar='Mpc',
        help='Maximum luminosity distance for injections')
    parser.add_argument(
        '--waveform', default='o2-uberbank',
        help='Waveform approximant')
    parser.add_argument(
        '--nsamples', type=int, default=100000,
        help='Output this many injections')
    parser.add_argument(
        '-o', '--output', type=FileType('wb'), default='-',
        metavar='INJ.xml[.gz]', help='Output file, optionally gzip-compressed')
    parser.add_argument(
        '-j', '--jobs', type=int, default=1, const=None, nargs='?',
        help='Number of threads')
    return parser


def main(args=None):
    """Entry point for the ``bayestar-inject`` command-line tool."""
    import warnings

    from astropy.table import Table
    from ligo.lw import lsctables
    from ligo.lw import utils as ligolw_utils
    from ligo.lw import ligolw
    import lal.series
    from scipy import stats

    p = parser()
    args = p.parse_args(args)

    if args.min_snr is not None:
        warnings.warn(
            'The --min-snr threshold option is deprecated. '
            'Please use the synonymous --snr-threshold option instead.',
            UserWarning)
        args.snr_threshold = args.min_snr

    xmldoc = ligolw.Document()
    xmlroot = xmldoc.appendChild(ligolw.LIGO_LW())
    process = register_to_xmldoc(xmldoc, p, args)

    # Read PSDs
    psds = list(
        lal.series.read_psd_xmldoc(
            ligolw_utils.load_fileobj(
                args.reference_psd,
                contenthandler=lal.series.PSDContentHandler)).values())

    if len(psds) < args.min_triggers:
        # BUG FIX: this previously called ``parser.error``, but ``parser``
        # here is the module-level factory *function*, which has no
        # ``error`` attribute and would raise AttributeError. The parser
        # instance is ``p``.
        p.error(
            f'The number of PSDs ({len(psds)}) must be greater than or equal '
            f'to the value of --min-triggers ({args.min_triggers}).')

    gwcosmo = GWCosmo(getattr(cosmology, args.cosmology))

    if args.distribution:
        # Preset distributions: hard-coded astrophysical and broad mass and
        # spin priors for NS and BH components.
        ns_mass_min = 1.0
        ns_mass_max = 2.0
        bh_mass_min = 5.0
        bh_mass_max = 50.0

        ns_astro_spin_min = -0.05
        ns_astro_spin_max = +0.05
        ns_astro_mass_dist = stats.norm(1.33, 0.09)
        ns_astro_spin_dist = stats.uniform(
            ns_astro_spin_min, ns_astro_spin_max - ns_astro_spin_min)

        ns_broad_spin_min = -0.4
        ns_broad_spin_max = +0.4
        ns_broad_mass_dist = stats.uniform(
            ns_mass_min, ns_mass_max - ns_mass_min)
        ns_broad_spin_dist = stats.uniform(
            ns_broad_spin_min, ns_broad_spin_max - ns_broad_spin_min)

        bh_astro_spin_min = -0.99
        bh_astro_spin_max = +0.99
        bh_astro_mass_dist = stats.pareto(b=1.3)
        bh_astro_spin_dist = stats.uniform(
            bh_astro_spin_min, bh_astro_spin_max - bh_astro_spin_min)

        bh_broad_spin_min = -0.99
        bh_broad_spin_max = +0.99
        bh_broad_mass_dist = stats.reciprocal(bh_mass_min, bh_mass_max)
        bh_broad_spin_dist = stats.uniform(
            bh_broad_spin_min, bh_broad_spin_max - bh_broad_spin_min)

        if args.distribution.startswith('bns_'):
            m1_min = m2_min = ns_mass_min
            m1_max = m2_max = ns_mass_max
            if args.distribution.endswith('_astro'):
                x1_min = x2_min = ns_astro_spin_min
                x1_max = x2_max = ns_astro_spin_max
                m1_dist = m2_dist = ns_astro_mass_dist
                x1_dist = x2_dist = ns_astro_spin_dist
            elif args.distribution.endswith('_broad'):
                x1_min = x2_min = ns_broad_spin_min
                x1_max = x2_max = ns_broad_spin_max
                m1_dist = m2_dist = ns_broad_mass_dist
                x1_dist = x2_dist = ns_broad_spin_dist
            else:  # pragma: no cover
                assert_not_reached()
        elif args.distribution.startswith('nsbh_'):
            m1_min = bh_mass_min
            m1_max = bh_mass_max
            m2_min = ns_mass_min
            m2_max = ns_mass_max
            if args.distribution.endswith('_astro'):
                x1_min = bh_astro_spin_min
                x1_max = bh_astro_spin_max
                x2_min = ns_astro_spin_min
                x2_max = ns_astro_spin_max
                m1_dist = bh_astro_mass_dist
                m2_dist = ns_astro_mass_dist
                x1_dist = bh_astro_spin_dist
                x2_dist = ns_astro_spin_dist
            elif args.distribution.endswith('_broad'):
                x1_min = bh_broad_spin_min
                x1_max = bh_broad_spin_max
                x2_min = ns_broad_spin_min
                x2_max = ns_broad_spin_max
                m1_dist = bh_broad_mass_dist
                m2_dist = ns_broad_mass_dist
                x1_dist = bh_broad_spin_dist
                x2_dist = ns_broad_spin_dist
            else:  # pragma: no cover
                assert_not_reached()
        elif args.distribution.startswith('bbh_'):
            m1_min = m2_min = bh_mass_min
            m1_max = m2_max = bh_mass_max
            if args.distribution.endswith('_astro'):
                x1_min = x2_min = bh_astro_spin_min
                x1_max = x2_max = bh_astro_spin_max
                m1_dist = m2_dist = bh_astro_mass_dist
                x1_dist = x2_dist = bh_astro_spin_dist
            elif args.distribution.endswith('_broad'):
                x1_min = x2_min = bh_broad_spin_min
                x1_max = x2_max = bh_broad_spin_max
                m1_dist = m2_dist = bh_broad_mass_dist
                x1_dist = x2_dist = bh_broad_spin_dist
            else:  # pragma: no cover
                assert_not_reached()
        else:  # pragma: no cover
            assert_not_reached()

        dists = (m1_dist, m2_dist, x1_dist, x2_dist)

        # Construct mass1, mass2, spin1z, spin2z grid.
        m1 = np.geomspace(m1_min, m1_max, 10)
        m2 = np.geomspace(m2_min, m2_max, 10)
        x1 = np.linspace(x1_min, x1_max, 10)
        x2 = np.linspace(x2_min, x2_max, 10)
        params = m1, m2, x1, x2

        # Calculate the maximum distance on the grid.
        max_z = gwcosmo.get_max_z(
            psds, args.waveform, args.f_low,
            args.snr_threshold, args.min_triggers,
            *np.meshgrid(m1, m2, x1, x2, indexing='ij'), jobs=args.jobs)
        if args.max_distance is not None:
            new_max_z = cosmology.z_at_value(gwcosmo.cosmo.luminosity_distance,
                                             args.max_distance * units.Mpc)
            max_z[max_z > new_max_z] = new_max_z
        max_distance = gwcosmo.sensitive_distance(max_z).to_value(units.Mpc)

        # Find piecewise constant approximate upper bound on distance.
        max_distance = cell_max(max_distance)

        # Calculate V * T in each grid cell
        cdfs = [dist.cdf(param) for param, dist in zip(params, dists)]
        cdf_los = [cdf[:-1] for cdf in cdfs]
        cdfs = [np.diff(cdf) for cdf in cdfs]
        probs = np.prod(np.meshgrid(*cdfs, indexing='ij'), axis=0)
        probs /= probs.sum()
        probs *= 4/3*np.pi*max_distance**3
        volume = probs.sum()
        probs /= volume
        probs = probs.ravel()

        # Draw random grid cells
        dist = stats.rv_discrete(values=(np.arange(len(probs)), probs))
        indices = np.unravel_index(
            dist.rvs(size=args.nsamples), max_distance.shape)

        # Draw random intrinsic params from each cell
        cols = {}
        cols['mass1'], cols['mass2'], cols['spin1z'], cols['spin2z'] = [
            dist.ppf(stats.uniform(cdf_lo[i], cdf[i]).rvs(size=args.nsamples))
            for i, dist, cdf_lo, cdf in zip(indices, dists, cdf_los, cdfs)]
    elif args.distribution_samples:
        # Load distribution samples.
        samples = Table.read(args.distribution_samples)

        # Calculate the maximum sensitive distance for each sample.
        max_z = gwcosmo.get_max_z(
            psds, args.waveform, args.f_low,
            args.snr_threshold, args.min_triggers,
            samples['mass1'], samples['mass2'],
            samples['spin1z'], samples['spin2z'], jobs=args.jobs)
        if args.max_distance is not None:
            new_max_z = cosmology.z_at_value(gwcosmo.cosmo.luminosity_distance,
                                             args.max_distance * units.Mpc)
            max_z[max_z > new_max_z] = new_max_z
        max_distance = gwcosmo.sensitive_distance(max_z).to_value(units.Mpc)

        # Calculate V * T for each sample.
        probs = 1 / len(max_distance)
        probs *= 4/3*np.pi*max_distance**3
        volume = probs.sum()
        probs /= volume

        # Draw weighted samples for the simulated events.
        dist = stats.rv_discrete(values=(np.arange(len(probs)), probs))
        # Note that we do this in small batches because stats.rv_discrete.rvs
        # has quadratic memory usage, number of values times number of samples,
        # which might cause us to run out of memory if we did it all at once.
        n_batches = max(args.nsamples * len(probs) // 1_000_000_000, 1)
        batch_sizes = [len(subarray) for subarray in
                       np.array_split(np.empty(args.nsamples), n_batches)]
        indices = np.concatenate([dist.rvs(size=batch_size)
                                  for batch_size in batch_sizes])

        cols = {key: samples[key][indices]
                for key in ['mass1', 'mass2', 'spin1z', 'spin2z']}
    else:
        assert_not_reached()

    volumetric_rate = args.nsamples / volume * units.year**-1 * units.Mpc**-3

    # Swap binary components as needed to ensure that mass1 >= mass2.
    # Note that the .copy() is important.
    # See https://github.com/numpy/numpy/issues/14428
    swap = cols['mass1'] < cols['mass2']
    cols['mass1'][swap], cols['mass2'][swap] = \
        cols['mass2'][swap].copy(), cols['mass1'][swap].copy()
    cols['spin1z'][swap], cols['spin2z'][swap] = \
        cols['spin2z'][swap].copy(), cols['spin1z'][swap].copy()

    # Draw random extrinsic parameters
    cols['distance'] = stats.powerlaw(a=3, scale=max_distance[indices]).rvs(
        size=args.nsamples)
    cols['longitude'] = stats.uniform(0, 2 * np.pi).rvs(
        size=args.nsamples)
    cols['latitude'] = np.arcsin(stats.uniform(-1, 2).rvs(
        size=args.nsamples))
    cols['inclination'] = np.arccos(stats.uniform(-1, 2).rvs(
        size=args.nsamples))
    cols['polarization'] = stats.uniform(0, 2 * np.pi).rvs(
        size=args.nsamples)
    cols['coa_phase'] = stats.uniform(-np.pi, 2 * np.pi).rvs(
        size=args.nsamples)
    cols['time_geocent'] = stats.uniform(1e9, units.year.to(units.second)).rvs(
        size=args.nsamples)

    # Convert from sensitive distance to redshift and comoving distance.
    # FIXME: Replace this brute-force lookup table with a solver.
    z = np.linspace(0, max_z.max(), 10000)
    ds = gwcosmo.sensitive_distance(z).to_value(units.Mpc)
    dc = gwcosmo.cosmo.comoving_distance(z).to_value(units.Mpc)
    z_for_ds = interp1d(ds, z, kind='cubic', assume_sorted=True)
    dc_for_ds = interp1d(ds, dc, kind='cubic', assume_sorted=True)
    zp1 = 1 + z_for_ds(cols['distance'])
    cols['distance'] = dc_for_ds(cols['distance'])

    # Apply redshift factor to convert from comoving distance and source frame
    # masses to luminosity distance and observer frame masses.
    for key in ['distance', 'mass1', 'mass2']:
        cols[key] *= zp1

    # Populate sim_inspiral table
    sims = xmlroot.appendChild(lsctables.New(lsctables.SimInspiralTable))
    for row in zip(*cols.values()):
        sims.appendRow(
            **dict(
                dict.fromkeys(sims.validcolumns, None),
                process_id=process.process_id,
                simulation_id=sims.get_next_id(),
                waveform=args.waveform,
                f_lower=args.f_low,
                **dict(zip(cols.keys(), row))))

    # Record process end time.
    process.comment = str(volumetric_rate)
    process.set_end_time_now()

    # Write output file.
    write_fileobj(xmldoc, args.output)

ligo/skymap/tool/bayestar_localize_coincs.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
#
# Copyright (C) 2013-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Produce GW sky maps for all coincidences in search pipeline output database
in LIGO-LW XML, LIGO-LW SQLite, or PyCBC HDF5 format.

The distance prior is controlled by the ``--prior-distance-power`` argument.
If you set ``--prior-distance-power`` to k, then the distance prior is
proportional to r^k. The default is 2, uniform in volume.

If the ``--min-distance`` argument is omitted, it defaults to zero. If the
``--max-distance argument`` is omitted, it defaults to the SNR=4 horizon
distance of the most sensitive detector.

A FITS file is created for each sky map, having a filename of the form
``X.fits`` where X is the integer LIGO-LW row ID of the coinc. The ``OBJECT``
card in the FITS header is also set to the integer row ID.
"""

from . import (
    ArgumentParser, FileType, get_waveform_parser, get_posterior_parser,
    get_mcmc_parser, get_random_parser)


# Comment lines written into each output FITS header to explain that the
# OBJECT card holds a LIGO-LW coinc_event row ID.
ROW_ID_COMMENT = [
    '',
    'The integer value in the OBJECT card in this FITS header is a row ID',
    'that refers to a coinc_event table row in the input LIGO-LW document.',
    '']


def parser():
    """Construct the command-line argument parser for this tool."""
    p = ArgumentParser(
        parents=[get_waveform_parser(), get_posterior_parser(),
                 get_mcmc_parser(), get_random_parser()])
    p.add_argument(
        '-d', '--disable-detector', metavar='X1', type=str, nargs='+',
        help='disable certain detectors')
    p.add_argument(
        '--keep-going', '-k', default=False, action='store_true',
        help='Keep processing events if a sky map fails to converge')
    p.add_argument(
        'input', metavar='INPUT.{hdf,xml,xml.gz,sqlite}', default='-',
        nargs='+', type=FileType('rb'),
        help='Input LIGO-LW XML file, SQLite file, or PyCBC HDF5 files. '
             'For PyCBC, you must supply the coincidence file '
             '(e.g. "H1L1-HDFINJFIND.hdf" or "H1L1-STATMAP.hdf"), '
             'the template bank file (e.g. H1L1-BANK2HDF.hdf), '
             'the single-detector merged PSD files '
             '(e.g. "H1-MERGE_PSDS.hdf" and "L1-MERGE_PSDS.hdf"), '
             'and the single-detector merged trigger files '
             '(e.g. "H1-HDF_TRIGGER_MERGE.hdf" and '
             '"L1-HDF_TRIGGER_MERGE.hdf"), '
             'in any order.')
    p.add_argument(
        '--pycbc-sample', default='foreground',
        help='(PyCBC only) sample population')
    p.add_argument(
        '--coinc-event-id', type=int, nargs='*',
        help='run on only these specified events')
    p.add_argument(
        '--output', '-o', default='.',
        help='output directory')
    p.add_argument(
        '--condor-submit', action='store_true',
        help='submit to Condor instead of running locally')
    return p


def main(args=None):
    """Localize every selected coincidence and write one FITS file each.

    Parses the command line (or *args*), then either re-submits this same
    command to HTCondor with one job per event (``--condor-submit``), or
    localizes each event in-process with :func:`..bayestar.localize`,
    writing ``<row id>.fits`` into the output directory. With
    ``--keep-going``, localization failures are counted and reported at the
    end via :class:`RuntimeError` instead of aborting immediately.
    """
    opts = parser().parse_args(args)

    import logging
    log = logging.getLogger('BAYESTAR')

    # BAYESTAR imports.
    from .. import omp
    from ..io import fits, events
    from ..bayestar import localize

    # Other imports.
    import os
    import subprocess
    import sys

    log.info('Using %d OpenMP thread(s)', omp.num_threads)

    # Read coinc file.
    log.info(
        '%s:reading input files', ','.join(file.name for file in opts.input))
    event_source = events.open(*opts.input, sample=opts.pycbc_sample)

    # Optionally wrap the event source so that data from the named
    # detectors are ignored.
    if opts.disable_detector:
        event_source = events.detector_disabled.open(
            event_source, opts.disable_detector)

    os.makedirs(opts.output, exist_ok=True)

    if opts.condor_submit:
        # Fan out to HTCondor: queue one copy of this same command line per
        # event, with '--condor-submit' removed and a per-event
        # '--coinc-event-id $(cid)' appended via the Condor queue variable.
        if opts.seed is not None:
            raise NotImplementedError(
                '--seed does not yet work with --condor-submit')
        if opts.coinc_event_id:
            raise ValueError(
                'must not set --coinc-event-id with --condor-submit')
        with subprocess.Popen(['condor_submit'],
                              text=True, stdin=subprocess.PIPE) as proc:
            f = proc.stdin
            print('''
                  accounting_group = ligo.dev.o4.cbc.pe.bayestar
                  on_exit_remove = (ExitBySignal == False) && (ExitCode == 0)
                  on_exit_hold = (ExitBySignal == True) || (ExitCode != 0)
                  on_exit_hold_reason = (ExitBySignal == True \
                    ? strcat("The job exited with signal ", ExitSignal) \
                    : strcat("The job exited with code ", ExitCode))
                  request_memory = 2000 MB
                  request_disk = 100 MB
                  universe = vanilla
                  getenv = true
                  executable = /usr/bin/env
                  JobBatchName = BAYESTAR
                  environment = "OMP_NUM_THREADS=1"
                  ''', file=f)
            print('error =', os.path.join(opts.output, '$(cid).err'), file=f)
            print('arguments = "',
                  *(arg for arg in sys.argv if arg != '--condor-submit'),
                  '--coinc-event-id $(cid)"', file=f)
            print('queue cid in', *event_source, file=f)
        # Propagate condor_submit's exit status as our own.
        sys.exit(proc.returncode)

    # Restrict processing to explicitly requested events, if any were given.
    if opts.coinc_event_id:
        event_source = {key: event_source[key] for key in opts.coinc_event_id}

    count_sky_maps_failed = 0

    # Loop over all sngl_inspiral <-> sngl_inspiral coincs.
    for coinc_event_id, event in event_source.items():
        # Loop over sky localization methods
        log.info('%d:computing sky map', coinc_event_id)
        if opts.chain_dump:
            chain_dump = f'{coinc_event_id}.hdf5'
        else:
            chain_dump = None
        try:
            sky_map = localize(
                event, opts.waveform, opts.f_low, opts.min_distance,
                opts.max_distance, opts.prior_distance_power,
                opts.cosmology, mcmc=opts.mcmc, chain_dump=chain_dump,
                enable_snr_series=opts.enable_snr_series,
                f_high_truncate=opts.f_high_truncate,
                rescale_loglikelihood=opts.rescale_loglikelihood)
            sky_map.meta['objid'] = coinc_event_id
            sky_map.meta['comment'] = ROW_ID_COMMENT

        except (ArithmeticError, ValueError):
            # Record the failure; abort unless --keep-going was requested.
            log.exception('%d:sky localization failed', coinc_event_id)
            count_sky_maps_failed += 1
            if not opts.keep_going:
                raise
        else:
            log.info('%d:saving sky map', coinc_event_id)
            filename = f'{coinc_event_id}.fits'
            fits.write_sky_map(
                os.path.join(opts.output, filename), sky_map, nest=True)

    # With --keep-going, deferred failures are surfaced here as one error.
    if count_sky_maps_failed > 0:
        raise RuntimeError("{0} sky map{1} did not converge".format(
            count_sky_maps_failed, 's' if count_sky_maps_failed > 1 else ''))

ligo/skymap/tool/bayestar_localize_lvalert.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
#
# Copyright (C) 2013-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
r"""
Listen for new events from IGWN Alert and perform sky localization.

`bayestar-localize-lvalert` supports two modes of operation. You can
explicitly specify the GraceDb ID on the command line, as in::

    $ bayestar-localize-lvalert T90713

Or, `bayestar-localize-lvalert` can read GraceDB IDs from stdin (e.g., from the
terminal, or redirected from a fifo)::

    $ mkfifo /var/run/bayestar
    $ tail -F /var/run/bayestar | bayestar_localize_lvalert &
    $ echo T90713 > /var/run/bayestar
"""

from . import (
    ArgumentParser, EnableAction, get_waveform_parser, get_posterior_parser,
    get_mcmc_parser, get_random_parser, iterlines)


def parser():
    """Construct the command-line argument parser for this tool."""
    p = ArgumentParser(
        parents=[get_waveform_parser(), get_posterior_parser(),
                 get_mcmc_parser(), get_random_parser()])
    p.add_argument(
        '-d', '--disable-detector', metavar='X1', type=str, nargs='+',
        help='disable certain detectors')
    p.add_argument(
        '-N', '--dry-run', action='store_true',
        help='Dry run; do not update GraceDB entry')
    p.add_argument(
        '--no-tag', action='store_true',
        help='Do not set lvem tag for GraceDB entry')
    p.add_argument(
        '-o', '--output', metavar='FILE.fits[.gz]', default='bayestar.fits',
        help='Name for uploaded file')
    p.add_argument(
        '--enable-multiresolution', action=EnableAction, default=True,
        help='generate a multiresolution HEALPix map')
    p.add_argument(
        'graceid', metavar='G123456', nargs='*',
        help='Run on these GraceDB IDs. If no GraceDB IDs are listed on the '
        'command line, then read newline-separated GraceDB IDs from stdin.')
    return p


def main(args=None):
    """Localize GraceDB events and upload the resulting sky maps.

    Reads GraceDB IDs from the command line (or, if none were given, from
    stdin line by line), localizes each event with
    :func:`..bayestar.localize`, and uploads the FITS file back to GraceDB.
    With ``--dry-run`` the FITS file is written to the current directory
    instead and any failure re-raises; otherwise errors are logged (also to
    GraceDB) and processing continues with the next event.
    """
    opts = parser().parse_args(args)

    import logging
    import os
    import re
    import sys
    import tempfile
    import urllib.parse
    from ..bayestar import localize, rasterize
    from ..io import fits
    from ..io import events
    from .. import omp
    from ..util.file import rename
    import ligo.gracedb.logging
    import ligo.gracedb.rest

    log = logging.getLogger('BAYESTAR')

    log.info('Using %d OpenMP thread(s)', omp.num_threads)

    # If no GraceDB IDs were specified on the command line, then read them
    # from stdin line-by-line.
    graceids = opts.graceid if opts.graceid else iterlines(sys.stdin)

    # Fire up a GraceDb client
    # FIXME: Mimic the behavior of the GraceDb command line client, where the
    # environment variable GRACEDB_SERVICE_URL overrides the default service
    # URL. It would be nice to get this behavior into the gracedb package
    # itself.
    gracedb = ligo.gracedb.rest.GraceDb(
        os.environ.get(
            'GRACEDB_SERVICE_URL', ligo.gracedb.rest.DEFAULT_SERVICE_URL))

    # Determine the base URL for event pages.
    scheme, netloc, *_ = urllib.parse.urlparse(gracedb._service_url)
    base_url = urllib.parse.urlunparse((scheme, netloc, 'events', '', '', ''))

    # Derive the chain dump filename from the output filename by swapping
    # the FITS extension for .hdf5.
    if opts.chain_dump:
        chain_dump = re.sub(r'.fits(.gz)?$', r'.hdf5', opts.output)
    else:
        chain_dump = None

    # Tags applied to the GraceDB log entry for the uploaded file.
    tags = ("sky_loc",)
    if not opts.no_tag:
        tags += ("lvem",)

    event_source = events.gracedb.open(graceids, gracedb)

    # Optionally ignore data from the named detectors.
    if opts.disable_detector:
        event_source = events.detector_disabled.open(
            event_source, opts.disable_detector)

    for graceid in event_source.keys():

        try:
            event = event_source[graceid]
        except:  # noqa: E722
            # Deliberately broad: a bad event must not stop the listener.
            log.exception('failed to read event %s from GraceDB', graceid)
            continue

        # Send log messages to GraceDb too
        if not opts.dry_run:
            handler = ligo.gracedb.logging.GraceDbLogHandler(gracedb, graceid)
            handler.setLevel(logging.INFO)
            logging.root.addHandler(handler)

        # A little bit of Cylon humor
        log.info('by your command...')

        try:
            # perform sky localization
            log.info("starting sky localization")
            sky_map = localize(
                event, opts.waveform, opts.f_low, opts.min_distance,
                opts.max_distance, opts.prior_distance_power, opts.cosmology,
                mcmc=opts.mcmc, chain_dump=chain_dump,
                enable_snr_series=opts.enable_snr_series,
                f_high_truncate=opts.f_high_truncate,
                rescale_loglikelihood=opts.rescale_loglikelihood)
            if not opts.enable_multiresolution:
                sky_map = rasterize(sky_map)
            sky_map.meta['objid'] = str(graceid)
            sky_map.meta['url'] = '{}/{}'.format(base_url, graceid)
            log.info("sky localization complete")

            # upload FITS file
            with tempfile.TemporaryDirectory() as fitsdir:
                fitspath = os.path.join(fitsdir, opts.output)
                fits.write_sky_map(fitspath, sky_map, nest=True)
                log.debug('wrote FITS file: %s', opts.output)
                if opts.dry_run:
                    rename(fitspath, os.path.join('.', opts.output))
                else:
                    gracedb.writeLog(
                        graceid, "BAYESTAR rapid sky localization ready",
                        filename=fitspath, tagname=tags)
                log.debug('uploaded FITS file')
        except KeyboardInterrupt:
            # Produce log message and then exit if we receive SIGINT (ctrl-C).
            log.exception("sky localization failed")
            raise
        except:  # noqa: E722
            # Produce log message for any otherwise uncaught exception.
            # Unless we are in dry-run mode, keep going.
            log.exception("sky localization failed")
            if opts.dry_run:
                # Then re-raise the exception if we are in dry-run mode
                raise
        # NOTE: 'handler' is only defined when not in dry-run mode; the
        # guard below matches the one that created it.
        if not opts.dry_run:
            # Remove old log handler
            logging.root.removeHandler(handler)
            del handler

ligo/skymap/tool/bayestar_mcmc.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
#
# Copyright (C) 2013-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Markov-Chain Monte Carlo sky localization."""

from . import (
    ArgumentParser, FileType, get_waveform_parser,
    get_posterior_parser, get_random_parser)


def parser():
    """Construct the command-line argument parser for this tool."""
    p = ArgumentParser(parents=[get_waveform_parser(),
                                get_posterior_parser(),
                                get_random_parser()])
    p.add_argument(
        'input', metavar='INPUT.{hdf,xml,xml.gz,sqlite}', default='-',
        nargs='+', type=FileType('rb'),
        help='Input LIGO-LW XML file, SQLite file, or PyCBC HDF5 files. '
             'For PyCBC, you must supply the coincidence file '
             '(e.g. "H1L1-HDFINJFIND.hdf" or "H1L1-STATMAP.hdf"), '
             'the template bank file (e.g. H1L1-BANK2HDF.hdf), '
             'the single-detector merged PSD files '
             '(e.g. "H1-MERGE_PSDS.hdf" and "L1-MERGE_PSDS.hdf"), '
             'and the single-detector merged trigger files '
             '(e.g. "H1-HDF_TRIGGER_MERGE.hdf" and '
             '"L1-HDF_TRIGGER_MERGE.hdf"), '
             'in any order.')
    p.add_argument(
        '--pycbc-sample', default='foreground',
        help='(PyCBC only) sample population')
    p.add_argument(
        '--coinc-event-id', type=int, nargs='*',
        help='run on only these specified events')
    p.add_argument(
        '--output', '-o', default='.',
        help='output directory')
    p.add_argument(
        '--condor-submit', action='store_true',
        help='submit to Condor instead of running locally')

    # Options to pin individual source parameters to fixed values.
    fixed = p.add_argument_group(
        'fixed parameter options',
        'Options to hold certain parameters constant')
    fixed.add_argument('--ra', type=float, metavar='DEG',
                       help='Right ascension')
    fixed.add_argument('--dec', type=float, metavar='DEG',
                       help='Declination')
    fixed.add_argument('--distance', type=float, metavar='Mpc',
                       help='Luminosity distance')

    return p


def identity(x):
    """Return *x* unchanged (the no-op coordinate transform)."""
    return x


def main(args=None):
    """Run MCMC sky localization for each selected coincidence.

    For every event, conditions the data and prior, samples the posterior
    with ``ez_emcee`` in transformed coordinates (sin of declination and
    cosine of inclination, so the prior is flat in each sampled variable),
    transforms the chain back to physical parameters, and writes the
    posterior samples to ``<row id>.hdf5`` in the output directory.
    ``--condor-submit`` instead queues one HTCondor job per event.
    """
    opts = parser().parse_args(args)

    import logging
    log = logging.getLogger('BAYESTAR')

    # BAYESTAR imports.
    from ..io import events, hdf5
    from ..bayestar import condition, condition_prior, ez_emcee, log_post

    # Other imports.
    from astropy.table import Table
    import numpy as np
    import os
    import subprocess
    import sys

    import lal

    # Read coinc file.
    log.info(
        '%s:reading input files', ','.join(file.name for file in opts.input))
    event_source = events.open(*opts.input, sample=opts.pycbc_sample)

    os.makedirs(opts.output, exist_ok=True)

    if opts.condor_submit:
        # Fan out to HTCondor: queue one copy of this same command line per
        # event, with '--condor-submit' removed and a per-event
        # '--coinc-event-id $(cid)' appended via the Condor queue variable.
        if opts.seed is not None:
            raise NotImplementedError(
                '--seed does not yet work with --condor-submit')
        if opts.coinc_event_id:
            raise ValueError(
                'must not set --coinc-event-id with --condor-submit')
        with subprocess.Popen(['condor_submit'],
                              text=True, stdin=subprocess.PIPE) as proc:
            f = proc.stdin
            print('''
                  accounting_group = ligo.dev.o4.cbc.pe.bayestar
                  on_exit_remove = (ExitBySignal == False) && (ExitCode == 0)
                  on_exit_hold = (ExitBySignal == True) || (ExitCode != 0)
                  on_exit_hold_reason = (ExitBySignal == True \
                    ? strcat("The job exited with signal ", ExitSignal) \
                    : strcat("The job exited with code ", ExitCode))
                  request_memory = 1000 MB
                  universe = vanilla
                  getenv = true
                  executable = /usr/bin/env
                  JobBatchName = BAYESTAR
                  environment = "OMP_NUM_THREADS=1"
                  ''', file=f)
            print('error =', os.path.join(opts.output, '$(cid).err'), file=f)
            print('log =', os.path.join(opts.output, '$(cid).log'), file=f)
            print('arguments = "',
                  *(arg for arg in sys.argv if arg != '--condor-submit'),
                  '--coinc-event-id $(cid)"', file=f)
            print('queue cid in', *event_source, file=f)
        sys.exit(proc.returncode)

    # Restrict processing to explicitly requested events, if any were given.
    if opts.coinc_event_id:
        event_source = {key: event_source[key] for key in opts.coinc_event_id}

    # Loop over all sngl_inspiral <-> sngl_inspiral coincs.
    for int_coinc_event_id, event in event_source.items():
        coinc_event_id = 'coinc_event:coinc_event_id:{}'.format(
            int_coinc_event_id)

        log.info('%s:preparing', coinc_event_id)

        epoch, sample_rate, epochs, snrs, responses, locations, horizons = \
            condition(event, waveform=opts.waveform, f_low=opts.f_low,
                      enable_snr_series=opts.enable_snr_series,
                      f_high_truncate=opts.f_high_truncate)

        min_distance, max_distance, prior_distance_power, cosmology = \
            condition_prior(horizons, opts.min_distance, opts.max_distance,
                            opts.prior_distance_power, opts.cosmology)

        gmst = lal.GreenwichMeanSiderealTime(epoch)

        # Sampling bounds and the forward/reverse transforms between the
        # physical parameters (names) and the sampled coordinates
        # (transformed_names): dec <-> sin_dec, inclination <-> u = cos(inc).
        max_abs_t = 2 * snrs.data.shape[1] / sample_rate
        xmin = [0, -1, min_distance, -1, 0, 0]
        xmax = [2 * np.pi, 1, max_distance, 1, 2 * np.pi, 2 * max_abs_t]
        names = 'ra dec distance inclination twopsi time'.split()
        transformed_names = 'ra sin_dec distance u twopsi time'.split()
        forward_transforms = [identity, np.sin, identity,
                              np.cos, identity, identity]
        reverse_transforms = [identity, np.arcsin, identity,
                              np.arccos, identity, identity]
        kwargs = dict(min_distance=min_distance, max_distance=max_distance,
                      prior_distance_power=prior_distance_power,
                      cosmology=cosmology, gmst=gmst, sample_rate=sample_rate,
                      epochs=epochs, snrs=snrs, responses=responses,
                      locations=locations, horizons=horizons)

        # Fix parameters
        # Iterate in reverse index order so deleting an entry does not shift
        # the indices of entries yet to be processed.
        for i, key in reversed(list(enumerate(['ra', 'dec', 'distance']))):
            value = getattr(opts, key)
            if value is None:
                continue

            if key in ['ra', 'dec']:
                # FIXME: figure out a more elegant way to address different
                # units in command line arguments and posterior samples
                value = np.deg2rad(value)

            # Pass the fixed value to the posterior and remove the parameter
            # from the sampled dimensions.
            kwargs[transformed_names[i]] = forward_transforms[i](value)
            del (xmin[i], xmax[i], names[i], transformed_names[i],
                 forward_transforms[i], reverse_transforms[i])

        log.info('%s:sampling', coinc_event_id)

        # Run MCMC
        chain = ez_emcee(log_post, xmin, xmax, kwargs=kwargs, vectorize=True)

        # Transform back from sin_dec to dec and cos_inclination to inclination
        for i, func in enumerate(reverse_transforms):
            chain[:, i] = func(chain[:, i])

        # Create Astropy table
        chain = Table(rows=chain, names=names, copy=False)

        log.info('%s:saving posterior samples', coinc_event_id)

        hdf5.write_samples(
            chain,
            os.path.join(opts.output, '{}.hdf5'.format(int_coinc_event_id)),
            path='/bayestar/posterior_samples', overwrite=True)

ligo/skymap/tool/bayestar_realize_coincs.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
#
# Copyright (C) 2013-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Synthesize triggers for simulated sources by realizing Gaussian measurement
errors in SNR and time of arrival. The input file (or stdin if the input file
is omitted) should be an optionally gzip-compressed LIGO-LW XML file of the
form produced by lalapps_inspinj. The output file (or stdout if omitted) will
be an optionally gzip-compressed LIGO-LW XML file containing single-detector
triggers and coincidences.

The root-mean square measurement error depends on the SNR of the signal, so
there is a choice for how to generate perturbed time and phase measurements:

 - `zero-noise`: no measurement error at all
 - `gaussian-noise`: measurement error for a matched filter in Gaussian noise
"""

import copy
import functools

import lal
import numpy as np

from . import (
    ArgumentParser, EnableAction, FileType, get_random_parser,
    register_to_xmldoc, write_fileobj)


def parser():
    """Construct the command-line argument parser for this tool."""
    # Determine list of known detectors for command line arguments.
    known_ifos = sorted(det.frDetector.prefix
                        for det in lal.CachedDetectors)

    p = ArgumentParser(parents=[get_random_parser()])
    p.add_argument(
        'input', metavar='IN.xml[.gz]', type=FileType('rb'),
        default='-', help='Name of input file')
    p.add_argument(
        '-o', '--output', metavar='OUT.xml[.gz]', type=FileType('wb'),
        default='-', help='Name of output file')
    p.add_argument(
        '-j', '--jobs', type=int, default=1, const=None, nargs='?',
        help='Number of threads')
    p.add_argument(
        '--detector', metavar='|'.join(known_ifos), nargs='+',
        help='Detectors to use', choices=known_ifos, required=True)
    p.add_argument(
        '--waveform',
        help='Waveform to use for injections (overrides values in '
        'sim_inspiral table)')
    p.add_argument(
        '--snr-threshold', type=float, default=4.,
        help='Single-detector SNR threshold')
    p.add_argument(
        '--net-snr-threshold', type=float, default=12.,
        help='Network SNR threshold')
    p.add_argument(
        '--keep-subthreshold', action='store_true',
        help='Keep sub-threshold triggers that do not contribute to the '
        'network SNR')
    p.add_argument(
        '--min-triggers', type=int, default=2,
        help='Emit coincidences only when at least this many triggers '
        'are found')
    p.add_argument(
        '--min-distance', type=float, default=0.0,
        help='Skip events with distance less than this value')
    p.add_argument(
        '--max-distance', type=float, default=float('inf'),
        help='Skip events with distance greater than this value')
    p.add_argument(
        '--measurement-error', default='zero-noise',
        choices=('zero-noise', 'gaussian-noise'),
        help='How to compute the measurement error')
    p.add_argument(
        '--enable-snr-series', action=EnableAction, default=True,
        help='Enable output of SNR time series')
    p.add_argument(
        '--reference-psd', metavar='PSD.xml[.gz]', type=FileType('rb'),
        required=True, help='Name of PSD file')
    p.add_argument(
        '--f-low', type=float,
        help='Override low frequency cutoff found in sim_inspiral table')
    p.add_argument(
        '--f-high', type=float,
        help='Set high frequency cutoff to simulate early warning')
    p.add_argument(
        '--duty-cycle', type=float, default=1.0,
        help='Single-detector duty cycle')
    p.add_argument(
        '-P', '--preserve-ids', action='store_true',
        help='Preserve original simulation_ids')
    return p


def simulate_snr(ra, dec, psi, inc, distance, epoch, gmst, H, S,
                 response, location, measurement_error='zero-noise',
                 duration=0.1):
    """Simulate the complex matched-filter SNR time series for one detector.

    Parameters
    ----------
    ra, dec, psi, inc : float
        Sky position, polarization angle, and inclination of the source
        (radians; presumably geocentric equatorial coordinates — confirm
        against caller).
    distance : float
        Luminosity distance (same units as the horizon distance implied by
        *H* and *S*).
    epoch :
        Geocentric reference time (``lal.LIGOTimeGPS``-like: has
        ``gpsSeconds`` and ``gpsNanoSeconds``).
    gmst : float
        Greenwich mean sidereal time at *epoch*.
    H, S :
        Signal spectral density of the template and detector noise PSD, as
        accepted by ``filter.signal_psd_series``.
    response, location :
        Detector response tensor and position, as accepted by
        ``lal.ComputeDetAMResponse`` / ``lal.TimeDelayFromEarthCenter``.
    measurement_error : {'zero-noise', 'gaussian-noise'}
        Whether to add simulated Gaussian noise to the SNR series.
    duration : float
        Length in seconds of the autocorrelation sequence to compute.

    Returns
    -------
    horizon, snr, phase, toa, tseries
        Horizon distance, peak |SNR|, phase at the peak, interpolated time
        of arrival, and the truncated complex SNR time series
        (``lal.COMPLEX8TimeSeries``).
    """
    from scipy.interpolate import interp1d

    from ..bayestar import filter
    from ..bayestar.interpolation import interpolate_max

    # Calculate whitened template autocorrelation sequence.
    HS = filter.signal_psd_series(H, S)
    n = len(HS.data.data)
    acor, sample_rate = filter.autocorrelation(HS, duration)

    # Calculate time, amplitude, and phase.
    u = np.cos(inc)
    u2 = np.square(u)
    signal_model = filter.SignalModel(HS)
    horizon = signal_model.get_horizon_distance()
    Fplus, Fcross = lal.ComputeDetAMResponse(response, ra, dec, psi, gmst)
    toa = lal.TimeDelayFromEarthCenter(location, ra, dec, epoch)
    z = (0.5 * (1 + u2) * Fplus + 1j * u * Fcross) * horizon / distance

    # Calculate complex autocorrelation sequence.
    # The full two-sided series is the one-sided autocorrelation mirrored
    # about its peak (conjugated on the negative-lag side), scaled by z.
    snr_series = z * np.concatenate((acor[:0:-1].conj(), acor))

    # If requested, add noise.
    if measurement_error == 'gaussian-noise':
        sigmasq = 4 * np.sum(HS.deltaF * np.abs(HS.data.data))
        amp = 4 * n * HS.deltaF**0.5 * np.sqrt(HS.data.data / sigmasq)
        N = lal.CreateCOMPLEX16FrequencySeries(
            '', HS.epoch, HS.f0, HS.deltaF, HS.sampleUnits, n)
        N.data.data = amp * (
            np.random.randn(n) + 1j * np.random.randn(n))
        noise_term, sample_rate_2 = filter.autocorrelation(
            N, len(snr_series) / sample_rate, normalize=False)
        assert sample_rate == sample_rate_2
        snr_series += noise_term

    # Shift SNR series to the nearest sample.
    # Split the sub-sample arrival time into an integer number of samples
    # plus a fractional remainder in (-0.5, 0.5], then resample the series
    # onto the integer-sample grid by cubic interpolation.
    int_samples, frac_samples = divmod(
        (1e-9 * epoch.gpsNanoSeconds + toa) * sample_rate, 1)
    if frac_samples > 0.5:
        int_samples += 1
        frac_samples -= 1
    epoch = lal.LIGOTimeGPS(epoch.gpsSeconds, 0)
    n = len(acor) - 1
    mprime = np.arange(-n, n + 1)
    m = mprime + frac_samples
    re, im = (
        interp1d(m, x, kind='cubic', bounds_error=False, fill_value=0)(mprime)
        for x in (snr_series.real, snr_series.imag))
    snr_series = re + 1j * im

    # Find the trigger values.
    # Peak search is restricted to the central half of the series; the true
    # maximum is then refined to sub-sample precision by Lanczos
    # interpolation.
    i_nearest = np.argmax(np.abs(snr_series[n-n//2:n+n//2+1])) + n-n//2
    i_interp, z_interp = interpolate_max(i_nearest, snr_series,
                                         n // 2, method='lanczos')
    toa = epoch + (int_samples + i_interp - n) / sample_rate
    snr = np.abs(z_interp)
    phase = np.angle(z_interp)

    # Shift and truncate the SNR time series.
    epoch += (int_samples + i_nearest - n - n // 2) / sample_rate
    snr_series = snr_series[(i_nearest - n // 2):(i_nearest + n // 2 + 1)]
    tseries = lal.CreateCOMPLEX8TimeSeries(
        'snr', epoch, 0, 1 / sample_rate,
        lal.DimensionlessUnit, len(snr_series))
    tseries.data.data = snr_series
    return horizon, snr, phase, toa, tseries


def simulate(seed, sim_inspiral, psds, responses, locations, measurement_error,
             f_low=None, f_high=None, waveform=None):
    """Simulate the matched-filter output for one injection in all detectors.

    Returns a list with one entry per detector: the result of
    ``simulate_snr`` evaluated with that detector's PSD, antenna response,
    and location.
    """
    from ..bayestar import filter

    # Each worker process seeds its own random number state (see main()).
    np.random.seed(seed)

    # Extrinsic parameters from the injection row.
    ra = sim_inspiral.longitude
    dec = sim_inspiral.latitude
    psi = sim_inspiral.polarization
    inc = sim_inspiral.inclination
    DL = sim_inspiral.distance
    # phi = sim_inspiral.coa_phase  # arbitrary?
    epoch = sim_inspiral.time_geocent
    gmst = lal.GreenwichMeanSiderealTime(epoch)

    # Fall back to the injection row for unspecified waveform settings.
    f_low = f_low or sim_inspiral.f_lower
    waveform = waveform or sim_inspiral.waveform

    # Intrinsic parameters, forwarded by keyword.
    intrinsic = {
        key: getattr(sim_inspiral, key)
        for key in ('mass1', 'mass2', 'spin1x', 'spin1y', 'spin1z',
                    'spin2x', 'spin2y', 'spin2z')}

    # Signal power spectrum, shared by every detector.
    H = filter.sngl_inspiral_psd(
        waveform, f_min=f_low, f_final=f_high, **intrinsic)

    return [
        simulate_snr(
            ra, dec, psi, inc, DL, epoch, gmst, H, S, response, location,
            measurement_error=measurement_error)
        for S, response, location in zip(psds, responses, locations)]


def main(args=None):
    """Run the command-line tool: simulate coincidences from injections.

    Reads a reference PSD file and an injection file, simulates the
    per-detector SNRs for each injection, applies duty-cycle and SNR
    threshold cuts, and writes the surviving events to a LIGO-LW XML
    document as coinc/sngl_inspiral/sim_inspiral tables.
    """
    p = parser()
    opts = p.parse_args(args)

    # LIGO-LW XML imports.
    from ligo.lw import ligolw
    from ligo.lw.param import Param
    from ligo.lw.utils.search_summary import append_search_summary
    from ligo.lw import utils as ligolw_utils
    from ligo.lw.lsctables import (
        New, CoincDefTable, CoincID, CoincInspiralTable, CoincMapTable,
        CoincTable, ProcessParamsTable, ProcessTable, SimInspiralTable,
        SnglInspiralTable, TimeSlideTable)

    # glue, LAL and pylal imports.
    from ligo import segments
    import lal
    import lal.series
    import lalsimulation
    from lalinspiral.inspinjfind import InspiralSCExactCoincDef
    from lalinspiral.thinca import InspiralCoincDef
    from tqdm import tqdm

    # BAYESTAR imports.
    from ..io.events.ligolw import ContentHandler
    from ..bayestar import filter
    from ..util.progress import progress_map

    # Read PSDs, interpolate them, and select them in the order of the
    # detectors requested on the command line.
    xmldoc = ligolw_utils.load_fileobj(
        opts.reference_psd, contenthandler=lal.series.PSDContentHandler)
    psds = lal.series.read_psd_xmldoc(xmldoc, root_name=None)
    psds = {
        key: filter.InterpolatedPSD(filter.abscissa(psd), psd.data.data)
        for key, psd in psds.items() if psd is not None}
    psds = [psds[ifo] for ifo in opts.detector]

    # Extract simulation table from injection file.
    inj_xmldoc = ligolw_utils.load_fileobj(
        opts.input, contenthandler=ContentHandler)
    orig_sim_inspiral_table = SimInspiralTable.get_table(inj_xmldoc)

    # Prune injections that are outside distance limits.
    orig_sim_inspiral_table[:] = [
        row for row in orig_sim_inspiral_table
        if opts.min_distance <= row.distance <= opts.max_distance]

    # Open output file.
    xmldoc = ligolw.Document()
    xmlroot = xmldoc.appendChild(ligolw.LIGO_LW())

    # Create tables. Process and ProcessParams tables are copied from the
    # injection file.
    coinc_def_table = xmlroot.appendChild(New(CoincDefTable))
    coinc_inspiral_table = xmlroot.appendChild(New(CoincInspiralTable))
    coinc_map_table = xmlroot.appendChild(New(CoincMapTable))
    coinc_table = xmlroot.appendChild(New(CoincTable))
    xmlroot.appendChild(ProcessParamsTable.get_table(inj_xmldoc))
    xmlroot.appendChild(ProcessTable.get_table(inj_xmldoc))
    sim_inspiral_table = xmlroot.appendChild(New(SimInspiralTable))
    sngl_inspiral_table = xmlroot.appendChild(New(SnglInspiralTable))
    time_slide_table = xmlroot.appendChild(New(TimeSlideTable))

    # Write process metadata to output file.
    process = register_to_xmldoc(
        xmldoc, p, opts, instruments=opts.detector,
        comment="Simulated coincidences")

    # Add search summary to output file.
    all_time = segments.segment([lal.LIGOTimeGPS(0), lal.LIGOTimeGPS(2e9)])
    append_search_summary(xmldoc, process, inseg=all_time, outseg=all_time)

    # Create a time slide entry.  Needed for coinc_event rows.
    time_slide_id = time_slide_table.get_time_slide_id(
        {ifo: 0 for ifo in opts.detector}, create_new=process)

    # Populate CoincDef table: one definition for the simulated inspiral
    # coincidences, one for the sim<->coinc "found" associations below.
    inspiral_coinc_def = copy.copy(InspiralCoincDef)
    inspiral_coinc_def.coinc_def_id = coinc_def_table.get_next_id()
    coinc_def_table.append(inspiral_coinc_def)
    found_coinc_def = copy.copy(InspiralSCExactCoincDef)
    found_coinc_def.coinc_def_id = coinc_def_table.get_next_id()
    coinc_def_table.append(found_coinc_def)

    # Precompute values that are common to all simulations.
    detectors = [lalsimulation.DetectorPrefixToLALDetector(ifo)
                 for ifo in opts.detector]
    responses = [det.response for det in detectors]
    locations = [det.location for det in detectors]

    func = functools.partial(simulate, psds=psds,
                             responses=responses, locations=locations,
                             measurement_error=opts.measurement_error,
                             f_low=opts.f_low, f_high=opts.f_high,
                             waveform=opts.waveform)

    # Make sure that each thread gets a different random number state.
    # We start by drawing a random integer s in the main thread, and
    # then the i'th subprocess will seed itself with the integer i + s.
    #
    # The seed must be an unsigned 32-bit integer, so if there are n
    # threads, then s must be drawn from the interval [0, 2**32 - n).
    #
    # Note that *we* are thread 0, so there are a total of
    # n=1+len(sim_inspiral_table) threads.
    #
    # NOTE(review): sim_inspiral_table is the still-empty *output* table at
    # this point; the comment above presumably intends the number of
    # injections, len(orig_sim_inspiral_table) — confirm.
    seed = np.random.randint(0, 2 ** 32 - len(sim_inspiral_table) - 1)
    np.random.seed(seed)

    with tqdm(desc='accepted') as progress:
        # Simulations run in parallel; each worker i seeds with seed + 1 + i.
        for sim_inspiral, simulation in zip(
                orig_sim_inspiral_table,
                progress_map(
                    func,
                    np.arange(len(orig_sim_inspiral_table)) + seed + 1,
                    orig_sim_inspiral_table, jobs=opts.jobs)):

            sngl_inspirals = []
            used_snr_series = []
            net_snr = 0.0  # accumulates the sum of squared per-ifo SNRs
            count_triggers = 0

            # Loop over individual detectors and create SnglInspiral entries.
            for ifo, (horizon, abs_snr, arg_snr, toa, series) \
                    in zip(opts.detector, simulation):

                if np.random.uniform() > opts.duty_cycle:
                    # Detector was randomly "off"; drop this detector.
                    continue
                elif abs_snr >= opts.snr_threshold:
                    # Above threshold: count the trigger toward the
                    # network SNR.
                    count_triggers += 1
                    net_snr += np.square(abs_snr)
                elif not opts.keep_subthreshold:
                    # Below threshold and sub-threshold triggers are not
                    # being kept: drop this detector.
                    continue

                # Create SnglInspiral entry.  Unlisted valid columns are
                # filled with None via dict.fromkeys.
                used_snr_series.append(series)
                sngl_inspirals.append(
                    sngl_inspiral_table.RowType(**dict(
                        dict.fromkeys(sngl_inspiral_table.validcolumns, None),
                        process_id=process.process_id,
                        ifo=ifo,
                        mass1=sim_inspiral.mass1,
                        mass2=sim_inspiral.mass2,
                        spin1x=sim_inspiral.spin1x,
                        spin1y=sim_inspiral.spin1y,
                        spin1z=sim_inspiral.spin1z,
                        spin2x=sim_inspiral.spin2x,
                        spin2y=sim_inspiral.spin2y,
                        spin2z=sim_inspiral.spin2z,
                        end=toa,
                        snr=abs_snr,
                        coa_phase=arg_snr,
                        f_final=opts.f_high,
                        eff_distance=horizon / abs_snr)))

            # Root-sum-square of the per-detector SNRs.
            net_snr = np.sqrt(net_snr)

            # If too few triggers were found, then skip this event.
            if count_triggers < opts.min_triggers:
                continue

            # If network SNR < threshold, then the injection is not found.
            # Skip it.
            if net_snr < opts.net_snr_threshold:
                continue

            # Add Coinc table entry.
            coinc = coinc_table.appendRow(
                coinc_event_id=coinc_table.get_next_id(),
                process_id=process.process_id,
                coinc_def_id=inspiral_coinc_def.coinc_def_id,
                time_slide_id=time_slide_id,
                insts=opts.detector,
                nevents=len(opts.detector),
                likelihood=None)

            # Add CoincInspiral table entry.  The end time is the mean of
            # the per-detector end times (ns() gives nanoseconds).
            coinc_inspiral_table.appendRow(
                coinc_event_id=coinc.coinc_event_id,
                instruments=[
                    sngl_inspiral.ifo for sngl_inspiral in sngl_inspirals],
                end=lal.LIGOTimeGPS(1e-9 * np.mean([
                    sngl_inspiral.end.ns()
                    for sngl_inspiral in sngl_inspirals
                    if sngl_inspiral.end is not None])),
                mass=sim_inspiral.mass1 + sim_inspiral.mass2,
                mchirp=sim_inspiral.mchirp,
                combined_far=0.0,  # Not provided
                false_alarm_rate=0.0,  # Not provided
                minimum_duration=None,  # Not provided
                snr=net_snr)

            # Record all sngl_inspiral records and associate them with coincs.
            for sngl_inspiral, series in zip(sngl_inspirals, used_snr_series):
                # Give this sngl_inspiral record an id and add it to the table.
                sngl_inspiral.event_id = sngl_inspiral_table.get_next_id()
                sngl_inspiral_table.append(sngl_inspiral)

                if opts.enable_snr_series:
                    # Attach the complex SNR time series, tagged with the
                    # event id so downstream tools can match it up.
                    elem = lal.series.build_COMPLEX8TimeSeries(series)
                    elem.appendChild(
                        Param.from_pyvalue('event_id', sngl_inspiral.event_id))
                    xmlroot.appendChild(elem)

                # Add CoincMap entry.
                coinc_map_table.appendRow(
                    coinc_event_id=coinc.coinc_event_id,
                    table_name=sngl_inspiral_table.tableName,
                    event_id=sngl_inspiral.event_id)

            # Record injection
            if not opts.preserve_ids:
                sim_inspiral.simulation_id = sim_inspiral_table.get_next_id()
            sim_inspiral_table.append(sim_inspiral)

            progress.update()

    # Record coincidence associating injections with events.  Accepted
    # injections and their coincs were appended in lockstep above, so the
    # i'th sim_inspiral corresponds to the i'th inspiral coinc (CoincID(i)).
    for i, sim_inspiral in enumerate(sim_inspiral_table):
        coinc = coinc_table.appendRow(
            coinc_event_id=coinc_table.get_next_id(),
            process_id=process.process_id,
            coinc_def_id=found_coinc_def.coinc_def_id,
            time_slide_id=time_slide_id,
            instruments=None,
            nevents=None,
            likelihood=None)
        coinc_map_table.appendRow(
            coinc_event_id=coinc.coinc_event_id,
            table_name=sim_inspiral_table.tableName,
            event_id=sim_inspiral.simulation_id)
        coinc_map_table.appendRow(
            coinc_event_id=coinc.coinc_event_id,
            table_name=coinc_table.tableName,
            event_id=CoincID(i))

    # Record process end time.
    process.set_end_time_now()

    # Write output file.
    write_fileobj(xmldoc, opts.output)

ligo/skymap/tool/bayestar_sample_model_psd.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
#
# Copyright (C) 2014-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Construct a LIGO-LW XML power spectral density file for a network of
detectors by evaluating a model power noise sensitivity curve.
"""

from argparse import SUPPRESS
import inspect

from . import ArgumentParser, FileType, register_to_xmldoc, write_fileobj

psd_name_prefix = 'SimNoisePSD'


def parser():
    """Construct the command-line argument parser.

    The available PSD models are discovered by introspecting
    ``lalsimulation`` for callables named ``SimNoisePSD*`` whose SWIG
    docstrings match one of the two supported call signatures.
    """
    import lal
    import lalsimulation

    # Get names of PSD functions.  ``func.__doc__`` may be None for
    # callables without docstrings; substitute '' so the ``in`` tests
    # cannot raise TypeError.
    psd_names = sorted(
        name[len(psd_name_prefix):]
        for name, func in inspect.getmembers(lalsimulation)
        if name.startswith(psd_name_prefix) and callable(func) and (
            '(double f) -> double' in (func.__doc__ or '') or
            '(REAL8FrequencySeries psd, double flow) -> int'
            in (func.__doc__ or '')))

    parser = ArgumentParser()
    parser.add_argument(
        '-o', '--output', metavar='OUT.xml[.gz]', type=FileType('wb'),
        default='-', help='Name of output file [default: stdout]')
    parser.add_argument(
        '--df', metavar='Hz', type=float, default=1.0,
        help='Frequency step size [default: %(default)s]')
    parser.add_argument(
        '--f-max', metavar='Hz', type=float, default=2048.0,
        help='Maximum frequency [default: %(default)s]')

    detector_group = parser.add_argument_group(
        'detector noise curves',
        'Options to select noise curves for detectors.\n\n'
        'All detectors support the following options:\n\n' +
        '\n'.join(psd_names))

    scale_group = parser.add_argument_group(
        'detector scaling',
        'Options to apply scale factors to noise curves for detectors.\n'
        'For example, a scale factor of 2 means that the amplitude spectral\n'
        'density is multiplied by 1/2 so that the range is multiplied by a 2.')

    # Add options for individual detectors.  SUPPRESS keeps unset options
    # out of the namespace so main() can detect them with getattr defaults.
    for detector in lal.CachedDetectors:
        name = detector.frDetector.name
        prefix = detector.frDetector.prefix
        detector_group.add_argument(
            '--' + prefix, choices=psd_names,
            metavar='func', default=SUPPRESS,
            help='PSD function for {0} detector'.format(name))
        scale_group.add_argument(
            '--' + prefix + '-scale', type=float, default=SUPPRESS,
            help='Scale range for {0} detector'.format(name))

    return parser


def main(args=None):
    """Run the command-line tool: evaluate model PSDs and write them to XML.

    For each detector selected on the command line, evaluates the chosen
    lalsimulation noise model on a regular frequency grid, applies the
    optional range scale factor, and writes all PSDs to a LIGO-LW XML file.
    """
    p = parser()
    opts = p.parse_args(args)

    import lal.series
    import lalsimulation
    import numpy as np
    from ..bayestar.filter import vectorize_swig_psd_func

    psds = {}

    # Frequency grid: n uniformly spaced samples starting at 0 Hz.
    n = int(opts.f_max // opts.df)
    f = np.arange(n) * opts.df

    detectors = [d.frDetector.prefix for d in lal.CachedDetectors]

    for detector in detectors:
        psd_name = getattr(opts, detector, None)
        if psd_name is None:
            # No noise model was requested for this detector.
            continue
        # A scale factor of s in range corresponds to 1/s**2 in power.
        scale = 1 / np.square(getattr(opts, detector + '_scale', 1.0))
        func = getattr(lalsimulation, psd_name_prefix + psd_name)
        series = lal.CreateREAL8FrequencySeries(
            psd_name, 0, 0, opts.df, lal.SecondUnit, n)
        # ``func.__doc__`` may be None; substitute '' so the ``in`` test
        # cannot raise TypeError.
        if '(double f) -> double' in (func.__doc__ or ''):
            # Scalar f -> PSD function: vectorize it and evaluate on the
            # frequency grid.
            series.data.data = vectorize_swig_psd_func(
                psd_name_prefix + psd_name)(f)
        else:
            # The function fills in a REAL8FrequencySeries in place.
            func(series, 0.0)

            # Find indices of first and last nonzero samples.
            nonzero = np.flatnonzero(series.data.data)
            # FIXME: int cast seems to be needed on old versions of Numpy
            first_nonzero = int(nonzero[0])
            last_nonzero = int(nonzero[-1])

            # Truncate to the nonzero band and record the new start
            # frequency.
            series = lal.CutREAL8FrequencySeries(
                series, first_nonzero, last_nonzero - first_nonzero + 1)
            series.f0 = first_nonzero * series.deltaF

            series.name = psd_name
        series.data.data *= scale
        psds[detector] = series

    xmldoc = lal.series.make_psd_xmldoc(psds)
    register_to_xmldoc(xmldoc, p, opts)
    write_fileobj(xmldoc, opts.output)

ligo/skymap/tool/ligo_skymap_combine.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
#
# Copyright (C) 2018-2020  Tito Dal Canton, Eric Burns, Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Combine different sky localizations of a common event observed by different
instruments in order to form a more constrained localization.

If one of the input maps contains distance information (for instance from
BAYESTAR or LALInference) then the marginal distance posterior in the output
map is updated according to the restriction in sky location imposed by the
other input map(s). Only one input map can currently have distance
information.
"""

from . import ArgumentParser, FileType


def parser():
    """Construct the command-line argument parser."""
    p = ArgumentParser()
    p.add_argument('input', metavar='INPUT.fits[.gz]',
                   type=FileType('rb'), nargs='+',
                   help='Input sky localizations')
    # FIXME: the output option has type str because
    # astropy.io.fits.writeto() only honors the .gz extension when given a
    # file name string (as of 3.0.1).
    p.add_argument('output', metavar='OUTPUT.fits[.gz]', type=str,
                   help='Output combined sky localization')
    p.add_argument('--origin', type=str,
                   help='Optional tag describing the organization'
                        ' responsible for the combined output')
    return p


def main(args=None):
    """Run the command-line tool: combine several sky maps into one.

    Multiplies the (upsampled) probability maps together, renormalizes,
    carries over distance layers from the single input that has them, and
    writes the combined localization with the input headers recorded in
    HISTORY.
    """
    args = parser().parse_args(args)

    from textwrap import wrap
    import numpy as np
    import astropy_healpix as ah
    from astropy.io import fits
    from astropy.time import Time
    import healpy as hp

    from ..distance import parameters_to_marginal_moments
    from ..io import read_sky_map, write_sky_map

    input_skymaps = []
    dist_mu = dist_sigma = dist_norm = None
    for input_file in args.input:
        with fits.open(input_file) as hdus:
            # Keep the raw headers so they can be recorded in the output
            # HISTORY at the end.
            header = hdus[0].header.copy()
            header.extend(hdus[1].header)
            has_distance = 'DISTMU' in hdus[1].columns.names
            data, meta = read_sky_map(hdus, nest=True,
                                      distances=has_distance)

        if has_distance:
            # Only a single input may carry distance layers.
            if dist_mu is not None:
                raise RuntimeError('only one input localization can have'
                                   ' distance information')
            dist_mu = data[1]
            dist_sigma = data[2]
            dist_norm = data[3]
        else:
            # Normalize to a tuple so data[0] is always the probability map.
            data = (data,)

        nside = ah.npix_to_nside(len(data[0]))
        input_skymaps.append((nside, data[0], meta, header))

    max_nside = max(x[0] for x in input_skymaps)

    # upsample sky posteriors to maximum resolution and combine them
    combined_prob = None
    for nside, prob, _, _ in input_skymaps:
        if nside < max_nside:
            prob = hp.ud_grade(prob, max_nside, order_in='NESTED',
                               order_out='NESTED')
        if combined_prob is None:
            combined_prob = np.ones_like(prob)
        combined_prob *= prob

    # normalize joint posterior
    norm = combined_prob.sum()
    if norm == 0:
        raise RuntimeError('input sky localizations are disjoint')
    combined_prob /= norm

    out_kwargs = {'gps_creation_time': Time.now().gps,
                  'nest': True}
    if args.origin is not None:
        out_kwargs['origin'] = args.origin

    # average the various input event times
    input_gps = [x[2]['gps_time'] for x in input_skymaps if 'gps_time' in x[2]]
    if input_gps:
        out_kwargs['gps_time'] = np.mean(input_gps)

    # combine instrument tags
    out_instruments = set()
    for x in input_skymaps:
        if 'instruments' in x[2]:
            out_instruments.update(x[2]['instruments'])
    out_kwargs['instruments'] = ','.join(out_instruments)

    # update marginal distance posterior, if available
    if dist_mu is not None:
        # Upsample the distance layers to match the combined resolution.
        if ah.npix_to_nside(len(dist_mu)) < max_nside:
            dist_mu = hp.ud_grade(dist_mu, max_nside, order_in='NESTED',
                                  order_out='NESTED')
            dist_sigma = hp.ud_grade(dist_sigma, max_nside, order_in='NESTED',
                                     order_out='NESTED')
            dist_norm = hp.ud_grade(dist_norm, max_nside, order_in='NESTED',
                                    order_out='NESTED')
        # Recompute the marginal distance moments under the new sky weights.
        distmean, diststd = parameters_to_marginal_moments(combined_prob,
                                                           dist_mu,
                                                           dist_sigma)
        out_data = (combined_prob, dist_mu, dist_sigma, dist_norm)
        out_kwargs['distmean'] = distmean
        out_kwargs['diststd'] = diststd
    else:
        out_data = combined_prob

    # save input headers in output history
    out_kwargs['HISTORY'] = []
    for i, x in enumerate(input_skymaps):
        out_kwargs['HISTORY'].append('')
        out_kwargs['HISTORY'].append(
            'Headers of HDUs 0 and 1 of input file {:d}:'.format(i))
        out_kwargs['HISTORY'].append('')
        for line in x[3].tostring(sep='\n',
                                  endcard=False,
                                  padding=False).split('\n'):
            # Wrap to 72 characters, the width of a FITS HISTORY record.
            out_kwargs['HISTORY'].extend(wrap(line, 72))

    write_sky_map(args.output, out_data, **out_kwargs)

ligo/skymap/tool/ligo_skymap_constellations.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
#
# Copyright (C) 2019-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
List most likely constellations for a localization.

Just for fun, and for public outreach purposes.
"""

from . import ArgumentParser, FileType


def parser():
    """Construct the command-line argument parser."""
    p = ArgumentParser()
    p.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    p.add_argument(
        '-o', '--output', metavar='OUT.dat', type=FileType('w'), default='-',
        help='Name of output file')
    return p


def main(args=None):
    """Tabulate per-constellation probability for a sky localization."""
    opts = parser().parse_args(args)

    # Late imports
    import astropy_healpix as ah
    from astropy.coordinates import SkyCoord
    from astropy.table import Table
    from astropy import units as u
    import healpy as hp
    import numpy as np

    from ..io import fits

    prob, meta = fits.read_sky_map(opts.input.name, nest=None)

    # Look up the constellation containing each HEALPix pixel center.
    nside = ah.npix_to_nside(len(prob))
    pixels = np.arange(len(prob))
    lon, lat = hp.pix2ang(nside, pixels, lonlat=True, nest=meta['nest'])
    constellations = SkyCoord(lon * u.deg, lat * u.deg).get_constellation()

    # Sum the probability within each constellation and rank the result
    # in decreasing order of probability.
    result = Table({'prob': prob, 'constellation': constellations},
                   copy=False)
    result = result.group_by('constellation').groups.aggregate(np.sum)
    result.sort('prob')
    result.reverse()
    result.write(opts.output, format='ascii.tab')

ligo/skymap/tool/ligo_skymap_contour.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
#
# Copyright (C) 2015-2018  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
Create a contours for the credible levels of an all-sky probability map.
The input is a HEALPix probability map.
The output is a GeoJSON FeatureCollection (http://geojson.org/).
"""

from . import ArgumentParser, FileType


def parser():
    """Construct the command-line argument parser."""
    p = ArgumentParser()
    p.add_argument(
        '-o', '--output', metavar='FILE.geojson',
        default='-', type=FileType('w'), help='output file [default: stdout]')
    p.add_argument(
        '--contour', metavar='PERCENT', type=float, nargs='+', required=True,
        help='plot contour enclosing this percentage of probability mass')
    p.add_argument(
        '-i', '--interpolate',
        choices=['nearest', 'nested', 'bilinear'], default='nearest',
        help='resampling interpolation method')
    p.add_argument(
        '-s', '--simplify', action='store_true', help='simplify contour paths')
    p.add_argument(
        '-n', '--nside', metavar='NSIDE', type=int,
        help='optionally resample to the specified resolution '
        ' before generating contours')
    p.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    return p


def main(args=None):
    """Generate GeoJSON contours for the credible levels of a sky map."""
    opts = parser().parse_args(args)

    import json

    import healpy as hp
    import numpy as np

    from .. import postprocess
    from ..io import fits

    # Load the HEALPix probability map.
    prob, _ = fits.read_sky_map(opts.input.name, nest=True)

    # Optionally change resolution before contouring.  The interpolation
    # choices are exactly 'nearest', 'nested', and 'bilinear'.
    if opts.nside is not None:
        if opts.interpolate == 'bilinear':
            prob = postprocess.smooth_ud_grade(prob, opts.nside, nest=True)
        else:
            prob = hp.ud_grade(prob, opts.nside, order_in='NESTED', power=-2)
    if opts.interpolate == 'nested':
        prob = postprocess.interpolate_nested(prob, nest=True)

    # Assign each pixel the cumulative probability of all pixels at least
    # as probable, expressed as a percentage credible level.
    order = np.flipud(np.argsort(prob))
    credible_levels = np.empty_like(prob)
    credible_levels[order] = 100 * np.cumsum(prob[order])

    # Trace the requested credible-level contours.
    contours = list(postprocess.contour(
        credible_levels, opts.contour, nest=True, degrees=True,
        simplify=opts.simplify))

    # Emit one GeoJSON Feature per requested credible level.
    features = [
        {
            'type': 'Feature',
            'properties': {
                'credible_level': credible_level
            },
            'geometry': {
                'type': 'MultiLineString',
                'coordinates': path
            }
        }
        for credible_level, path in zip(opts.contour, contours)]
    json.dump({'type': 'FeatureCollection', 'features': features},
              opts.output)

ligo/skymap/tool/ligo_skymap_contour_moc.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
#
# Copyright (C) 2013-2023 Giuseppe Greco, Leo Singer, and CDS team.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
Create a contour for a credible level of an all-sky probability map. The input
is a HEALPix FITS probability map. The output is a `Multi-Order Coverage (MOC)
<http://ivoa.net/documents/MOC/>`_ FITS file.
"""

from . import ArgumentParser, FileType


def parser():
    """Construct the command-line argument parser."""
    p = ArgumentParser()
    p.add_argument(
        '-o', '--output', metavar='FILE.fits', required=True,
        help='output file')
    p.add_argument(
        '-c', '--contour', metavar='PERCENT', type=float, required=True,
        help='MOC region enclosing this percentage of probability \
              [range is 0-100]')
    p.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input multi-order or flatten \
                                      HEALPix FITS file')
    return p


def main(args=None):
    """Entry point: read a sky map and write the credible-region MOC."""
    p = parser()
    # Reuse the parser instance built above. The original called parser()
    # a second time just to parse the arguments, constructing a throwaway
    # duplicate parser.
    opts = p.parse_args(args)

    # Late imports
    import astropy_healpix as ah
    import astropy.units as u

    try:
        from mocpy import MOC
    except ImportError:
        p.error('This command-line tool requires mocpy >= 0.8.2. '
                'Please install it by running "pip install mocpy".')

    from ..io import read_sky_map

    # Read multi-order sky map
    skymap = read_sky_map(opts.input.name, moc=True)

    uniq = skymap['UNIQ']
    probdensity = skymap['PROBDENSITY']

    # Convert probability density per steradian into probability per pixel
    # by multiplying by each pixel's area at its refinement level.
    # (The pixel index returned alongside the level is not needed.)
    level, _ = ah.uniq_to_level_ipix(uniq)
    area = ah.nside_to_pixel_area(
        ah.level_to_nside(level)).to_value(u.steradian)

    prob = probdensity * area

    # Create MOC enclosing the requested credible level, expressed as a
    # cumulative probability fraction in [0, 1].
    contour_decimal = opts.contour / 100
    moc = MOC.from_valued_healpix_cells(
        uniq, prob, max_depth=level.max(),
        cumul_from=0.0, cumul_to=contour_decimal)

    # Write MOC
    moc.write(opts.output, format='fits', overwrite=True)

ligo/skymap/tool/ligo_skymap_flatten.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Convert a HEALPix FITS file from multi-resolution UNIQ indexing to the more
common IMPLICIT indexing.
"""

from . import ArgumentParser, FileType


def parser():
    """Build the command line parser for this tool."""
    p = ArgumentParser()
    p.add_argument('input', metavar='INPUT.fits[.gz]',
                   type=FileType('rb'), help='Input FITS file')
    p.add_argument('output', metavar='OUTPUT.fits',
                   type=FileType('wb'), help='Output FITS file')
    p.add_argument('--nside', type=int, help='Output HEALPix resolution')
    return p


def main(args=None):
    """Entry point: flatten a UNIQ-indexed sky map to IMPLICIT indexing."""
    args = parser().parse_args(args)

    # Late imports
    import logging
    import warnings
    import astropy_healpix as ah
    from astropy.io import fits
    from ..io import read_sky_map, write_sky_map
    from ..bayestar import rasterize

    log = logging.getLogger()

    # Translate the requested output nside (if given) to a HEALPix order.
    order = None if args.nside is None else ah.nside_to_level(args.nside)

    log.info('reading FITS file %s', args.input.name)
    hdus = fits.open(args.input)

    # Warn (but continue) if the input is not multi-resolution already.
    expected_ordering = 'NUNIQ'
    actual_ordering = hdus[1].header['ORDERING']
    if actual_ordering != expected_ordering:
        warnings.warn(
            'Expected the FITS file {} to have ordering {}, but it is {}'
            .format(args.input.name, expected_ordering, actual_ordering))

    log.debug('converting original FITS file to Astropy table')
    table = read_sky_map(hdus, moc=True)

    log.debug('flattening HEALPix tree')
    table = rasterize(table, order=order)

    log.info('writing FITS file %s', args.output.name)
    write_sky_map(args.output.name, table, nest=True)
    log.debug('done')

ligo/skymap/tool/ligo_skymap_from_samples.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
#
# Copyright (C) 2011-2023  Will M. Farr <will.farr@ligo.org>
#                          Leo P. Singer <leo.singer@ligo.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Generate a FITS sky map file from posterior samples using clustering and
kernel density estimation.

The input file should be an HDF5 file with the following columns:

*  ``ra``, ``rightascension``, or ``right_ascension``: J2000 right ascension in
    radians
*  ``dec`` or ``declination``: J2000 declination in radians
*  ``dist``, ``distance``, or ``luminosity_distance``: luminosity distance in
   Mpc (optional)

The output consists of two files:

*  ``skypost.obj``, a :mod:`pickle` representation of the kernel density
   estimator
*  ``skymap.fits.gz``, a 3D localization in HEALPix/FITS format
"""

from argparse import SUPPRESS

from . import (
    ArgumentParser, DirType, EnableAction, FileType, get_random_parser)


def parser():
    """Build the command line parser for this tool."""
    p = ArgumentParser(parents=[get_random_parser()])
    p.add_argument('samples', type=FileType('rb'), metavar='SAMPLES.hdf5',
                   help='posterior samples file')
    # Only present for backward compatibility with --samples syntax
    p.add_argument('--samples', action='store_false', dest='_ignored',
                   help=SUPPRESS)
    p.add_argument('--outdir', '-o', default='.',
                   type=DirType(create=True), help='output directory')
    p.add_argument('--fitsoutname', default='skymap.fits',
                   metavar='SKYMAP.fits[.gz]',
                   help='filename for the FITS file')
    p.add_argument('--loadpost', type=FileType('rb'),
                   metavar='SKYPOST.obj',
                   help='filename for pickled posterior state')
    p.add_argument('--maxpts', type=int,
                   help='maximum number of posterior points to use; if '
                   'omitted or greater than or equal to the number of '
                   'posterior samples, then use all samples')
    p.add_argument('--trials', type=int, default=5,
                   help='number of trials at each clustering number')
    p.add_argument('--enable-distance-map', action=EnableAction,
                   help='generate HEALPix map of distance estimates')
    p.add_argument('--enable-multiresolution', action=EnableAction,
                   default=True,
                   help='generate a multiresolution HEALPix map')
    p.add_argument('--top-nside', type=int, default=16,
                   help='choose a start nside before HEALPix refinement '
                   'steps (must be a valid nside)')
    p.add_argument('-j', '--jobs', type=int, default=1, const=None,
                   nargs='?', help='Number of threads')
    p.add_argument('--instruments', metavar='H1|L1|V1|...', nargs='+',
                   help='instruments to store in FITS header')
    p.add_argument('--objid', help='event ID to store in FITS header')
    p.add_argument('--path', type=str, default=None,
                   help="The path of the dataset within the HDF5 file")
    p.add_argument('--tablename', type=str, default='posterior_samples',
                   help='The name of the table to search for recursively '
                   'within the HDF5 file. By default, search for '
                   'posterior_samples')
    return p


def main(args=None):
    """Entry point: read posterior samples, fit a clustered sky KDE, and
    write both the pickled KDE and the FITS sky map to the output directory.
    """
    _parser = parser()
    args = _parser.parse_args(args)

    # Late imports
    from .. import io
    from ..io.hdf5 import _remap_colnames
    from ..bayestar import rasterize
    from .. import version
    from astropy.table import Table
    from astropy.time import Time
    import numpy as np
    import os
    import sys
    import pickle
    from ..kde import Clustered2Plus1DSkyKDE, Clustered2DSkyKDE
    import logging
    from textwrap import wrap

    log = logging.getLogger()

    log.info('reading samples')
    try:
        data = io.read_samples(args.samples.name, path=args.path,
                               tablename=args.tablename)
    except IOError:
        # FIXME: remove this code path once we support only HDF5
        data = Table.read(args.samples, format='ascii')
        _remap_colnames(data)

    # Optionally thin the chain to at most --maxpts randomly chosen samples.
    if args.maxpts is not None and args.maxpts < len(data):
        log.info('taking random subsample of chain')
        data = data[np.random.choice(len(data), args.maxpts, replace=False)]
    # The luminosity distance column is optional and may appear under
    # either of two names; dist stays None if neither is present.
    try:
        dist = data['dist']
    except KeyError:
        try:
            dist = data['distance']
        except KeyError:
            dist = None

    if args.loadpost is None:
        # Fit a fresh KDE from the samples.
        if dist is None:
            if args.enable_distance_map:
                _parser.error("The posterior samples file '{}' does not have "
                              "a distance column named 'dist' or 'distance'. "
                              "Cannot generate distance map. If you do not "
                              "intend to generate a distance map, then add "
                              "the '--disable-distance-map' command line "
                              "argument.".format(args.samples.name))
            pts = np.column_stack((data['ra'], data['dec']))
        else:
            pts = np.column_stack((data['ra'], data['dec'], dist))
        # Sky-only (2D) KDE, or sky+distance (2+1D) when a distance map
        # was requested.
        if args.enable_distance_map:
            cls = Clustered2Plus1DSkyKDE
        else:
            cls = Clustered2DSkyKDE
        skypost = cls(pts, trials=args.trials, jobs=args.jobs)

        log.info('pickling')
        with open(os.path.join(args.outdir, 'skypost.obj'), 'wb') as out:
            pickle.dump(skypost, out)
    else:
        # Resume from a previously pickled KDE instead of refitting.
        skypost = pickle.load(args.loadpost)
        skypost.jobs = args.jobs

    log.info('making skymap')
    hpmap = skypost.as_healpix(top_nside=args.top_nside)
    if not args.enable_multiresolution:
        hpmap = rasterize(hpmap)
    # Populate FITS header metadata, including a wrapped record of the
    # exact command line that produced this map.
    hpmap.meta.update(io.fits.metadata_for_version_module(version))
    hpmap.meta['creator'] = _parser.prog
    hpmap.meta['origin'] = 'LIGO/Virgo/KAGRA'
    hpmap.meta['gps_creation_time'] = Time.now().gps
    hpmap.meta['history'] = [
        '', 'Generated by running the following script:',
        *wrap(' '.join([_parser.prog] + sys.argv[1:]), 72)]
    if args.objid is not None:
        hpmap.meta['objid'] = args.objid
    if args.instruments:
        hpmap.meta['instruments'] = args.instruments
    if args.enable_distance_map:
        hpmap.meta['distmean'] = np.mean(dist)
        hpmap.meta['diststd'] = np.std(dist)

    # Record the event time from the first time-like column that exists;
    # the for/else warns only if none of the candidates is present.
    keys = ['time', 'time_mean', 'time_maxl']
    for key in keys:
        try:
            time = data[key]
        except KeyError:
            continue
        else:
            hpmap.meta['gps_time'] = time.mean()
            break
    else:
        log.warning(
            'Cannot determine the event time from any of the columns %r', keys)

    io.write_sky_map(os.path.join(args.outdir, args.fitsoutname),
                     hpmap, nest=True)

ligo/skymap/tool/ligo_skymap_plot.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
#
# Copyright (C) 2011-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
Plot an all-sky map on a Mollweide projection.
By default, plot in celestial coordinates (RA, Dec).

To plot in geographic coordinates (longitude, latitude) with major
coastlines overlaid, provide the ``--geo`` flag.

Public-domain cartographic data is courtesy of `Natural Earth
<http://www.naturalearthdata.com>`_ and processed with `MapShaper
<http://www.mapshaper.org>`_.
"""

from . import ArgumentParser, FileType, SQLiteType
from .matplotlib import get_figure_parser


def parser():
    """Build the command line parser for this tool."""
    p = ArgumentParser(parents=[get_figure_parser()])
    p.add_argument(
        '--annotate', default=False, action='store_true',
        help='annotate plot with information about the event')
    p.add_argument(
        '--contour', metavar='PERCENT', type=float, nargs='+',
        help='plot contour enclosing this percentage of'
        ' probability mass [may be specified multiple times, default: none]')
    p.add_argument(
        '--colorbar', default=False, action='store_true',
        help='Show colorbar')
    p.add_argument(
        '--radec', nargs=2, metavar='deg', type=float, action='append',
        default=[], help='right ascension (deg) and declination (deg) to mark')
    p.add_argument(
        '--inj-database', metavar='FILE.sqlite', type=SQLiteType('r'),
        help='read injection positions from database')
    p.add_argument(
        '--geo', action='store_true',
        help='Use a terrestrial reference frame with coordinates (lon, lat) '
             'instead of the celestial frame with coordinates (RA, dec) '
             'and draw continents on the map')
    p.add_argument(
        '--projection', type=str, default='mollweide',
        choices=['mollweide', 'aitoff', 'globe', 'zoom'],
        help='Projection style')
    p.add_argument(
        '--projection-center', metavar='CENTER',
        help='Specify the center for globe and zoom projections, e.g. 14h 10d')
    p.add_argument(
        '--zoom-radius', metavar='RADIUS',
        help='Specify the radius for zoom projections, e.g. 4deg')
    p.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    return p


def main(args=None):
    """Entry point: render the all-sky probability map and show or save it."""
    opts = parser().parse_args(args)

    # Late imports

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from ..io import fits
    from .. import plot
    from .. import postprocess
    import astropy_healpix as ah
    from astropy.coordinates import SkyCoord
    from astropy.time import Time
    from astropy import units as u

    # nest=None: presumably lets read_sky_map accept either pixel ordering
    # and report the one it found via metadata['nest'] — confirm against
    # fits.read_sky_map.
    skymap, metadata = fits.read_sky_map(opts.input.name, nest=None)
    nside = ah.npix_to_nside(len(skymap))

    # Convert sky map from probability to probability per square degree.
    deg2perpix = ah.nside_to_pixel_area(nside).to_value(u.deg**2)
    probperdeg2 = skymap / deg2perpix

    # Assemble keyword arguments for the projection axes.
    axes_args = {}
    if opts.geo:
        # Terrestrial frame: anchor the map to the Earth at the event time.
        axes_args['projection'] = 'geo'
        obstime = Time(metadata['gps_time'], format='gps').utc.isot
        axes_args['obstime'] = obstime
    else:
        axes_args['projection'] = 'astro'
    axes_args['projection'] += ' ' + opts.projection
    if opts.projection_center is not None:
        axes_args['center'] = SkyCoord(opts.projection_center)
    if opts.zoom_radius is not None:
        axes_args['radius'] = opts.zoom_radius
    ax = plt.axes(**axes_args)
    ax.grid()

    # Plot sky map.
    vmax = probperdeg2.max()
    img = ax.imshow_hpx(
        (probperdeg2, 'ICRS'), nested=metadata['nest'], vmin=0., vmax=vmax)

    # Add colorbar.
    if opts.colorbar:
        cb = plot.colorbar(img)
        cb.set_label(r'prob. per deg$^2$')

    # Add contours.
    if opts.contour:
        cls = 100 * postprocess.find_greedy_credible_levels(skymap)
        cs = ax.contour_hpx(
            (cls, 'ICRS'), nested=metadata['nest'],
            colors='k', linewidths=0.5, levels=opts.contour)
        # Escape the percent sign only when TeX rendering is active.
        fmt = r'%g\%%' if rcParams['text.usetex'] else '%g%%'
        plt.clabel(cs, fmt=fmt, fontsize=6, inline=True)

    # Add continents.
    if opts.geo:
        plt.plot(*plot.coastlines(), color='0.5', linewidth=0.5,
                 transform=ax.get_transform('world'))

    radecs = opts.radec
    if opts.inj_database:
        # Look up the sky position of the injection associated with this
        # event through the coincidence mapping tables. The trailing comma
        # in the unpacking asserts exactly one matching row.
        query = '''SELECT DISTINCT longitude, latitude FROM sim_inspiral AS si
                   INNER JOIN coinc_event_map AS cm1
                   ON (si.simulation_id = cm1.event_id)
                   INNER JOIN coinc_event_map AS cm2
                   ON (cm1.coinc_event_id = cm2.coinc_event_id)
                   WHERE cm2.event_id = ?
                   AND cm1.table_name = 'sim_inspiral'
                   AND cm2.table_name = 'coinc_event'
                   '''
        (ra, dec), = opts.inj_database.execute(
            query, (metadata['objid'],)).fetchall()
        radecs.append(np.rad2deg([ra, dec]).tolist())

    # Add markers (e.g., for injections or external triggers).
    for ra, dec in radecs:
        ax.plot_coord(
            SkyCoord(ra, dec, unit='deg'), '*',
            markerfacecolor='white', markeredgecolor='black', markersize=10)

    # Add a white outline to all text to make it stand out from the background.
    plot.outline_text(ax)

    if opts.annotate:
        text = []
        try:
            objid = metadata['objid']
        except KeyError:
            pass
        else:
            text.append('event ID: {}'.format(objid))
        if opts.contour:
            # Report the sky area enclosed by each requested credible level.
            pp = np.round(opts.contour).astype(int)
            ii = np.round(np.searchsorted(np.sort(cls), opts.contour) *
                          deg2perpix).astype(int)
            for i, p in zip(ii, pp):
                # FIXME: use Unicode symbol instead of TeX '$^2$'
                # because of broken fonts on Scientific Linux 7.
                text.append('{:d}% area: {:,d} deg²'.format(p, i))
        ax.text(1, 1, '\n'.join(text), transform=ax.transAxes, ha='right')

    # Show or save output.
    opts.output()

ligo/skymap/tool/ligo_skymap_plot_airmass.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
#
# Copyright (C) 2018-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Make an airmass chart for a LIGO/Virgo/KAGRA probability sky map."""

import numpy as np

from . import ArgumentParser, FileType, HelpChoicesAction
from .matplotlib import get_figure_parser


def parser():
    """Build the command line parser for this tool."""
    from astropy.coordinates import EarthLocation
    site_names = EarthLocation.get_site_names()
    p = ArgumentParser(parents=[get_figure_parser()])
    p.add_argument(
        '-v', '--verbose', action='store_true',
        help='Print airmass table to stdout')
    p.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    p.add_argument(
        '--time', help='UTC time')
    p.add_argument(
        '--site', metavar='SITE', choices=site_names, help='Observatory site')
    p.add_argument(
        '--help-site', action=HelpChoicesAction, choices=site_names)
    p.add_argument(
        '--site-longitude', metavar='DEG', type=float,
        help='Observatory longitude on the WGS84 ellipsoid. '
        'Mutually exclusive with --site.')
    p.add_argument(
        '--site-latitude', metavar='DEG', type=float,
        help='Observatory latitude on the WGS84 ellipsoid. '
        'Mutually exclusive with --site.')
    p.add_argument(
        '--site-height', metavar='METERS', type=float,
        help='Observatory height from the WGS84 ellipsoid. '
        'Mutually exclusive with --site.')
    p.add_argument(
        '--site-timezone',
        help='Observatory time zone, e.g. "US/Pacific". '
        'Mutually exclusive with --site.')
    return p


def condition_secz(x):
    """Condition secz airmass formula: values <= 0 are below the horizon,
    which we map to infinite airmass.
    """
    below_horizon = x <= 0
    return np.where(below_horizon, np.inf, x)


def clip_verylarge(x, max=1e300):
    """Clamp values to [-max, max] so that infinities become very large
    but finite numbers.
    """
    # Equivalent to np.clip(x, -max, max), spelled as a min/max pair.
    # (Parameter name `max` shadows the builtin but is kept for interface
    # compatibility with existing callers.)
    return np.minimum(np.maximum(x, -max), max)


def main(args=None):
    """Entry point: compute and plot probability-weighted airmass bands of
    the sky map over one night at the chosen observatory.
    """
    p = parser()
    opts = p.parse_args(args)

    # Late imports
    import operator
    import sys

    from astroplan import Observer
    from astroplan.plots import plot_airmass
    from astropy.coordinates import EarthLocation, SkyCoord
    from astropy.table import Table
    from astropy.time import Time
    from astropy import units as u
    from matplotlib import dates
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize
    from matplotlib.patches import Patch
    from matplotlib import pyplot as plt
    from tqdm import tqdm
    import pytz

    from ..io import fits
    from .. import moc
    from .. import plot  # noqa
    from ..extern.numpy.quantile import percentile

    # Build the Observer either from explicit coordinates or from a named
    # site; mixing the two specification styles is a usage error.
    if opts.site is None:
        if opts.site_longitude is None or opts.site_latitude is None:
            p.error('must specify either --site or both '
                    '--site-longitude and --site-latitude')
        location = EarthLocation(
            lon=opts.site_longitude * u.deg,
            lat=opts.site_latitude * u.deg,
            height=(opts.site_height or 0) * u.m)
        if opts.site_timezone is not None:
            location.info.meta = {'timezone': opts.site_timezone}
        observer = Observer(location)
    else:
        if not ((opts.site_longitude is None) and
                (opts.site_latitude is None) and
                (opts.site_height is None) and
                (opts.site_timezone is None)):
            p.error('argument --site not allowed with arguments '
                    '--site-longitude, --site-latitude, '
                    '--site-height, or --site-timezone')
        observer = Observer.at_site(opts.site)

    m = fits.read_sky_map(opts.input.name, moc=True)

    # Make an empty airmass chart centered on local midnight.
    t0 = Time(opts.time) if opts.time is not None else Time.now()
    t0 = observer.midnight(t0)
    ax = plot_airmass([], observer, t0, altitude_yaxis=True)

    # Remove the fake source and determine times that were used for the plot.
    for artist in ax.lines:
        artist.remove()
    times = Time(np.linspace(*ax.get_xlim()), format='plot_date')

    # Sky coordinates and per-pixel probabilities of the multi-order map.
    theta, phi = moc.uniq2ang(m['UNIQ'])
    coords = SkyCoord(phi, 0.5 * np.pi - theta, unit='rad')
    prob = moc.uniq2pixarea(m['UNIQ']) * m['PROBDENSITY']

    # Credible levels 90%, 80%, ..., 10%; for each level the two
    # percentiles 50 -/+ level/2 bound the band.
    levels = np.arange(90, 0, -10)
    nlevels = len(levels)
    percentiles = np.concatenate((50 - 0.5 * levels, 50 + 0.5 * levels))

    # One column per time step: probability-weighted airmass percentiles.
    airmass = np.column_stack([
        percentile(
            condition_secz(coords.transform_to(observer.altaz(t)).secz),
            percentiles,
            weights=prob)
        for t in tqdm(times)])

    # Shade each credible band between its lower and upper percentile curves.
    cmap = ScalarMappable(Normalize(0, 100), plt.get_cmap())
    for level, lo, hi in zip(levels, airmass[:nlevels], airmass[nlevels:]):
        ax.fill_between(
            times.plot_date,
            clip_verylarge(lo),  # Clip infinities to large but finite values
            clip_verylarge(hi),  # because fill_between cannot handle inf
            color=cmap.to_rgba(level), zorder=2)

    ax.legend(
        [Patch(facecolor=cmap.to_rgba(level)) for level in levels],
        ['{}%'.format(level) for level in levels])
    # ax.set_title('{} from {}'.format(m.meta['objid'], observer.name))

    # Adapted from astroplan
    start = times[0]
    twilights = [
        (times[0].datetime, 0.0),
        (observer.sun_set_time(
            Time(start), which='next').datetime, 0.0),
        (observer.twilight_evening_civil(
            Time(start), which='next').datetime, 0.1),
        (observer.twilight_evening_nautical(
            Time(start), which='next').datetime, 0.2),
        (observer.twilight_evening_astronomical(
            Time(start), which='next').datetime, 0.3),
        (observer.twilight_morning_astronomical(
            Time(start), which='next').datetime, 0.4),
        (observer.twilight_morning_nautical(
            Time(start), which='next').datetime, 0.3),
        (observer.twilight_morning_civil(
            Time(start), which='next').datetime, 0.2),
        (observer.sun_rise_time(
            Time(start), which='next').datetime, 0.1),
        (times[-1].datetime, 0.0),
    ]

    # Shade the chart background according to twilight darkness level.
    twilights.sort(key=operator.itemgetter(0))
    for i, twi in enumerate(twilights[1:], 1):
        if twi[1] != 0:
            ax.axvspan(twilights[i - 1][0], twilights[i][0],
                       ymin=0, ymax=1, color='grey', alpha=twi[1], linewidth=0)
        if twi[1] != 0.4:
            ax.axvspan(twilights[i - 1][0], twilights[i][0],
                       ymin=0, ymax=1, color='white', alpha=0.8 - 2 * twi[1],
                       zorder=3, linewidth=0)

    # Add local time axis
    timezone = (observer.location.info.meta or {}).get('timezone')
    if timezone:
        tzinfo = pytz.timezone(timezone)
        ax2 = ax.twiny()
        ax2.set_xlim(ax.get_xlim())
        ax2.set_xticks(ax.get_xticks())
        ax2.xaxis.set_major_formatter(dates.DateFormatter('%H:%M', tz=tzinfo))
        plt.setp(ax2.get_xticklabels(), rotation=-30, ha='right')
        ax2.set_xlabel("Time from {} [{}]".format(
            min(times).to_datetime(tzinfo).date(),
            timezone))

    if opts.verbose:
        # Write airmass table to stdout.
        times.format = 'isot'
        table = Table(masked=True)
        table['time'] = times
        table['sun_alt'] = np.ma.masked_greater_equal(
            observer.sun_altaz(times).alt, 0)
        table['sun_alt'].format = lambda x: '{}'.format(int(np.round(x)))
        for p, data in sorted(zip(percentiles, airmass)):
            table[str(p)] = np.ma.masked_invalid(data)
            table[str(p)].format = lambda x: '{:.01f}'.format(np.around(x, 1))
        table.write(sys.stdout, format='ascii.fixed_width')

    # Show or save output.
    opts.output()

ligo/skymap/tool/ligo_skymap_plot_coherence.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
#
# Copyright (C) 2011-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
Show a sky map's Bayes factor for coherence vs. incoherence as a bullet chart.
"""

from . import ArgumentParser, FileType
from .matplotlib import get_figure_parser


def parser():
    """Build the command line parser for this tool."""
    p = ArgumentParser(parents=[get_figure_parser()])
    p.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    p.set_defaults(colormap='RdYlBu')
    return p


def main(args=None):
    """Render the coherence-vs-incoherence bullet chart for a sky map.

    Reads only the FITS header (the log Bayes factor and optional object
    ID), composes a title, and delegates plotting to
    :func:`~ligo.skymap.plot.plot_bayes_factor`.
    """
    opts = parser().parse_args(args)

    # Late imports
    from astropy.io import fits
    import numpy as np
    from ..plot import plot_bayes_factor

    # Only the header of HDU 1 is needed; the pixel data are not read.
    header = fits.getheader(opts.input, 1)
    logb = header['LOGBCI']
    objid = header.get('OBJECT')

    # Assemble the plot title from its parts.
    logb_string = np.format_float_positional(logb, 1, trim='0', sign=True)
    title_parts = ['Coherence']
    if objid:
        title_parts.append(f'of {objid}')
    title_parts.append(fr'$[\ln\,B = {logb_string}]$')
    title = ' '.join(title_parts)

    plot_bayes_factor(logb, title=title, palette=opts.colormap)

    # Show or save output.
    opts.output()

ligo/skymap/tool/ligo_skymap_plot_observability.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
#
# Copyright (C) 2019-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Make an observability chart for a LIGO/Virgo/KAGRA probability sky map."""

import numpy as np

from . import ArgumentParser, FileType, HelpChoicesAction
from .matplotlib import get_figure_parser


def parser():
    """Build the command line argument parser for this tool."""
    from astropy.coordinates import EarthLocation
    # Named observatory sites known to astropy, used both as the choices
    # for --site and for the --help-site listing.
    site_names = EarthLocation.get_site_names()
    p = ArgumentParser(parents=[get_figure_parser()])
    p.add_argument(
        '-v', '--verbose', action='store_true',
        help='Print airmass table to stdout')
    p.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    p.add_argument('--time', help='UTC time')
    p.add_argument(
        '--max-airmass', default=2.5, type=float, help='Maximum airmass')
    p.add_argument(
        '--twilight', default='astronomical',
        choices=('astronomical', 'nautical', 'civil'),
        help='Twilight definition: astronomical (-18 degrees), '
             'nautical (-12 degrees), or civil (-6 degrees)')
    p.add_argument(
        '--site', nargs='*', default=[], metavar='SITE', choices=site_names,
        help='Observatory site')
    p.add_argument('--help-site', action=HelpChoicesAction, choices=site_names)
    # Custom observatories may be given as parallel lists of name,
    # longitude, latitude, and height.
    p.add_argument(
        '--site-name', nargs='*', default=[], help='Observatory name.')
    p.add_argument(
        '--site-longitude', nargs='*', default=[], metavar='DEG', type=float,
        help='Observatory longitude on the WGS84 ellipsoid.')
    p.add_argument(
        '--site-latitude', nargs='*', default=[], metavar='DEG', type=float,
        help='Observatory latitude on the WGS84 ellipsoid.')
    p.add_argument(
        '--site-height', nargs='*', default=[], metavar='METERS', type=float,
        help='Observatory height from the WGS84 ellipsoid.')
    return p


def condition_secz(x):
    """Condition secz airmass formula: values <= 0 are below the horizon,
    which we map to infinite airmass.
    """
    below_horizon = np.asarray(x) <= 0
    return np.where(below_horizon, np.inf, x)


def main(args=None):
    """Make an observability chart for a probability sky map.

    For each requested observatory, integrates the sky-map probability
    over the pixels that satisfy the twilight and airmass constraints at
    a grid of times spanning one day, and renders the result as one
    filled horizontal band per observatory.
    """
    p = parser()
    opts = p.parse_args(args)

    # Late imports
    from astroplan import (
        AirmassConstraint, AtNightConstraint, Observer, is_event_observable)
    from astropy.coordinates import EarthLocation, SkyCoord
    from astropy.time import Time
    from astropy import units as u
    from matplotlib import dates
    from matplotlib import pyplot as plt
    from tqdm import tqdm

    from ..io import fits
    from .. import moc
    from .. import plot  # noqa

    # The four --site-* options describe custom observatories as parallel
    # lists; reject the input unless all four have the same length.
    names = ('name', 'longitude', 'latitude', 'height')
    length0, *lengths = (
        len(getattr(opts, 'site_{}'.format(name))) for name in names)
    if not all(length0 == length for length in lengths):
        p.error('these options require equal numbers of arguments: {}'.format(
            ', '.join(
                '--site-{}'.format(name) for name in names)))

    # Named sites first, then custom sites; reversed so the first site
    # given on the command line ends up at the top of the chart.
    observers = [Observer.at_site(site) for site in opts.site]
    for name, lon, lat, height in zip(
            opts.site_name, opts.site_longitude, opts.site_latitude,
            opts.site_height):
        location = EarthLocation(
            lon=lon * u.deg,
            lat=lat * u.deg,
            height=(height or 0) * u.m)
        observers.append(Observer(location, name=name))
    observers = list(reversed(observers))

    m = fits.read_sky_map(opts.input.name, moc=True)

    # Time grid: 50 samples (np.linspace default) covering one day from
    # the requested start time (or now).
    t0 = Time(opts.time) if opts.time is not None else Time.now()
    times = t0 + np.linspace(0, 1) * u.day

    # Sky coordinates and per-pixel probability of the multi-order map.
    theta, phi = moc.uniq2ang(m['UNIQ'])
    coords = SkyCoord(phi, 0.5 * np.pi - theta, unit='rad')
    prob = np.asarray(moc.uniq2pixarea(m['UNIQ']) * m['PROBDENSITY'])

    constraints = [
        getattr(AtNightConstraint, 'twilight_{}'.format(opts.twilight))(),
        AirmassConstraint(opts.max_airmass)]

    # Scale the figure height with the number of observatories.
    fig = plt.figure()
    width, height = fig.get_size_inches()
    fig.set_size_inches(width, (len(observers) + 1) / 16 * width)
    ax = plt.axes()
    locator = dates.AutoDateLocator()
    formatter = dates.DateFormatter('%H:%M')
    ax.set_xlim([times[0].plot_date, times[-1].plot_date])
    ax.xaxis.set_major_formatter(formatter)
    ax.xaxis.set_major_locator(locator)
    ax.set_xlabel("Time from {0} [UTC]".format(min(times).datetime.date()))
    plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
    # One labeled row per observatory; suppress tick marks and frame so
    # only the shaded bands remain visible.
    ax.set_yticks(np.arange(len(observers)))
    ax.set_yticklabels([observer.name for observer in observers])
    ax.yaxis.set_tick_params(left=False)
    ax.grid(axis='x')
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)

    for i, observer in enumerate(tqdm(observers)):
        # Probability-weighted fraction of the sky map (in percent) that
        # is observable at each time sample.
        observable = 100 * np.dot(prob, is_event_observable(
            constraints, observer, coords, times))
        ax.contourf(
            times.plot_date, [i - 0.4, i + 0.4], np.tile(observable, (2, 1)),
            levels=np.arange(10, 110, 10), cmap=plt.get_cmap().reversed())

    plt.tight_layout()

    # Show or save output.
    opts.output()

ligo/skymap/tool/ligo_skymap_plot_pp_samples.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
#
# Copyright (C) 2017-2019  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Create a P-P plot to compare a posterior sample chain with a sky map."""

import re

import numpy as np

from . import ArgumentParser, FileType
from .matplotlib import get_figure_parser


def fmt(x, sigfigs, force_scientific=False):
    """Round and format a number in scientific notation."""
    # Number of decimal places needed, from the magnitude of x.
    decimals = sigfigs - int(np.floor(np.log10(x)))
    rounded = np.around(x, decimals)
    if decimals > 0 or force_scientific:
        # Format in exponent notation, then rewrite as LaTeX
        # ``$m \times 10^{e}$`` (stripping the '+' sign and any leading
        # zeros from the exponent).
        formatted = '{:.{}e}'.format(rounded, sigfigs)
        return re.sub(
            r'^(.*)e\+?(-?)0*(\d+)$', r'$\1 \\times 10^{\2\3}$', formatted)
    # Large values need no fractional digits: plain integer formatting.
    return '{:d}'.format(int(rounded))


def parser():
    """Build the command line argument parser for this tool."""
    p = ArgumentParser(parents=[get_figure_parser()])
    p.add_argument(
        'skymap', metavar='SKYMAP.fits[.gz]', type=FileType('rb'),
        help='FITS sky map file')
    p.add_argument(
        'samples', metavar='SAMPLES.hdf5', type=FileType('rb'),
        help='HDF5 posterior sample chain file')
    p.add_argument(
        '-m', '--max-points', type=int,
        help='Maximum number of posterior sample points '
             '[default: all of them]')
    p.add_argument(
        '-p', '--contour', default=[], nargs='+', type=float,
        metavar='PERCENT',
        help='Report the areas and volumes within the smallest contours '
             'containing this much probability.')
    return p


def main(args=None):
    """Create a P-P plot comparing a posterior sample chain with a sky map.

    Reads a sky map and a chain of posterior samples, computes the
    searched probability of each sample within the map, and draws the
    cumulative fraction of samples against searched probability, with
    optional annotations for requested credible contours.
    """
    args = parser().parse_args(args)

    # Late imports.
    from astropy.table import Table
    from matplotlib import pyplot as plt
    from scipy.interpolate import interp1d
    from .. import io
    from .. import plot as _  # noqa: F401
    from ..postprocess import find_injection_moc

    # Read input.
    skymap = io.read_sky_map(args.skymap.name, moc=True)
    chain = io.read_samples(args.samples.name)

    # If required, downselect to a smaller number of posterior samples.
    if args.max_points is not None:
        chain = Table(np.random.permutation(chain)[:args.max_points],
                      copy=False)

    # Calculate P-P plot.
    # Contours are given on the command line in percent; convert to
    # fractions for find_injection_moc.
    contours = np.asarray(args.contour)
    result = find_injection_moc(skymap,
                                chain['ra'], chain['dec'], chain['dist'],
                                contours=1e-2 * contours)

    # Make Matplotlib figure.
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='pp_plot')
    ax.add_diagonal()
    ax.add_series(result.searched_prob, label='R.A., Dec.')
    # Empirical CDF of searched area, used below to annotate the area
    # searched at each requested probability level.
    searched_area_func = interp1d(np.linspace(0, 1, len(chain)),
                                  np.sort(result.searched_area),
                                  bounds_error=False)
    # Distance and volume statistics are only available for 3D sky maps
    # (those with a DISTMU column).
    if 'DISTMU' in skymap.colnames:
        ax.add_series(result.searched_prob_dist, label='Distance')
        ax.add_series(result.searched_prob_vol, label='Volume')
        searched_vol_func = interp1d(np.linspace(0, 1, len(chain)),
                                     np.sort(result.searched_vol),
                                     bounds_error=False)
    # Annotate each requested contour with its area (and volume, if
    # available): first the credible-region values, then the searched
    # values at the same probability level.
    for p, area, vol in zip(
            args.contour, result.contour_areas, result.contour_vols):
        text = '{:g}%\n{} deg$^2$'.format(p, fmt(area, 2))
        if 'DISTMU' in skymap.colnames:
            text += '\n{} Mpc$^3$'.format(fmt(vol, 2, force_scientific=True))
        ax.annotate(
            text, (1e-2 * p, 1e-2 * p), (0, -150),
            xycoords='data', textcoords='offset points',
            horizontalalignment='right', backgroundcolor='white',
            arrowprops=dict(connectionstyle='bar,angle=0,fraction=0',
                            arrowstyle='-|>', linewidth=2, color='black'))
        area = searched_area_func(1e-2 * p)
        text = '{:g}%\n{} deg$^2$'.format(p, fmt(area, 2))
        if 'DISTMU' in skymap.colnames:
            vol = searched_vol_func(1e-2 * p)
            text += '\n{} Mpc$^3$'.format(fmt(vol, 2, force_scientific=True))
        ax.annotate(
            text, (1e-2 * p, 1e-2 * p), (-75, 0),
            xycoords='data', textcoords='offset points',
            horizontalalignment='right', verticalalignment='center',
            backgroundcolor='white',
            arrowprops=dict(connectionstyle='bar,angle=0,fraction=0',
                            arrowstyle='-|>', linewidth=2, color='black'))
    ax.set_xlabel('searched probability')
    ax.set_ylabel('cumulative fraction of posterior samples')
    ax.set_title(args.skymap.name)
    ax.legend()
    ax.grid()

    # Show or save output.
    args.output()

ligo/skymap/tool/ligo_skymap_plot_stats.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
#
# Copyright (C) 2013-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Create summary plots for sky maps of found injections, optionally binned
cumulatively by false alarm rate or SNR.
"""

import os

from . import ArgumentParser, FileType
from .matplotlib import MatplotlibFigureType


def parser():
    """Build the command line argument parser for this tool."""
    p = ArgumentParser()
    p.add_argument('--cumulative', action='store_true')
    p.add_argument('--normed', action='store_true')
    p.add_argument(
        '--group-by', choices=('far', 'snr'), metavar='far|snr',
        help='Group plots by false alarm rate (FAR) or '
             'signal to noise ratio (SNR)')
    p.add_argument(
        '--pp-confidence-interval', type=float, metavar='PCT',
        default=95, help='If all inputs files have the same number of '
        'samples, overlay binomial confidence bands for this percentage on '
        'the P--P plot')
    p.add_argument('--format', default='pdf', help='Matplotlib format')
    p.add_argument(
        'input', type=FileType('rb'), nargs='+',
        help='Name of input file generated by ligo-skymap-stats')
    p.add_argument('--output', '-o', default='.', help='output directory')
    return p


def main(args=None):
    """Create summary P-P plots and histograms for found-injection sky maps.

    Reads one or more output files from ligo-skymap-stats, optionally bins
    them cumulatively by FAR or SNR, and writes a set of P-P plots and
    histograms into per-bin subdirectories of the output directory.
    """
    opts = parser().parse_args(args)

    # Imports.
    from astropy.table import Table
    import matplotlib
    matplotlib.use('agg')
    from matplotlib import pyplot as plt
    from matplotlib import rcParams
    import numpy as np
    from tqdm import tqdm
    from .. import plot  # noqa

    # Read in all of the datasets listed as positional command line arguments.
    datasets = [Table.read(file, format='ascii') for file in opts.input]

    # Determine plot colors and labels
    filenames = [file.name for file in opts.input]
    labels = [os.path.splitext(os.path.basename(f))[0] for f in filenames]
    if rcParams['text.usetex']:
        # Escape labels for TeX rendering using verbatim text.
        labels = [r'\verb/' + label + '/' for label in labels]
    rcParams['savefig.format'] = opts.format
    metadata = MatplotlibFigureType.get_savefig_metadata(opts.format)

    # Normalize column names
    for dataset in datasets:
        if 'p_value' in dataset.colnames:
            dataset.rename_column('p_value', 'searched_prob')

    # Define, per grouping mode, how to map each event to a bin key and
    # how to turn a key into a subdirectory name and a plot title.
    if opts.group_by == 'far':

        def key_func(table):
            return -np.log10(table['far'])

        def key_to_dir(key):
            return 'far_1e{}'.format(-key)

        def key_to_title(key):
            return r'$\mathrm{{FAR}} \leq 10^{{{}}}$ Hz'.format(-key)

    elif opts.group_by == 'snr':

        def key_func(table):
            return table['snr']

        def key_to_dir(key):
            return 'snr_{}'.format(key)

        def key_to_title(key):
            return r'$\mathrm{{SNR}} \geq {}$'.format(key)

    else:

        # No grouping: a single bin containing all events.
        def key_func(table):
            return np.zeros(len(table))

        def key_to_dir(key):
            return '.'

        def key_to_title(key):
            return 'All events'

    if opts.group_by is not None:
        missing = [filename for filename, dataset in zip(filenames, datasets)
                   if opts.group_by not in dataset.colnames]
        if missing:
            raise RuntimeError(
                'The following files had no "'
                + opts.group_by + '" column: ' + ' '.join(missing))

    for dataset in datasets:
        dataset['key'] = key_func(dataset)

    if opts.group_by is not None:
        invalid = [filename for filename, dataset in zip(filenames, datasets)
                   if not np.all(np.isfinite(dataset['key']))]
        if invalid:
            raise RuntimeError(
                'The following files had invalid values in the "'
                + opts.group_by + '" column: ' + ' '.join(invalid))

    keys = np.concatenate([dataset['key'] for dataset in datasets])

    # Compose the histogram y-axis label from the selected options.
    histlabel = []
    if opts.cumulative:
        histlabel.append('cumulative')
    if opts.normed:
        histlabel.append('fraction')
    else:
        histlabel.append('number')
    histlabel.append('of injections')
    histlabel = ' '.join(histlabel)

    # (column suffix, x-axis label) pairs for the P-P plots, and
    # (column name, x-axis label) pairs for the histograms.
    pp_plot_settings = [
        ['', 'searched posterior mass'],
        ['_dist', 'distance CDF at true distance'],
        ['_vol', 'searched volumetric probability']]
    hist_settings = [
        ['searched_area', 'searched_area (deg$^2$)'],
        ['searched_vol', 'searched volume (Mpc$^3$)'],
        ['offset', 'angle from true location and mode of posterior (deg)'],
        ['runtime', 'run time (s)']]

    # One integer bin per unit of key, spanning the observed range;
    # each bin is cumulative (events with key >= bin value).
    keys = range(*np.floor([keys.min(), keys.max()+1]).astype(int))
    total = len(keys) * (len(pp_plot_settings) + len(hist_settings))
    with tqdm(total=total) as progress:
        for key in keys:
            filtered = [d[d['key'] >= key] for d in datasets]
            title = key_to_title(key)
            # If every dataset contributes the same number of events,
            # report that count in the title and enable the binomial
            # confidence band; otherwise disable it.
            nsamples = {len(d) for d in filtered}
            if len(nsamples) == 1:
                nsamples, = nsamples
                title += ' ({} events)'.format(nsamples)
            else:
                nsamples = None

            subdir = os.path.join(opts.output, key_to_dir(key))
            os.makedirs(subdir, exist_ok=True)

            # Make several different kinds of P-P plots
            for suffix, xlabel in pp_plot_settings:
                colname = 'searched_prob' + suffix
                fig = plt.figure(figsize=(6, 6))
                ax = fig.add_subplot(111, projection='pp_plot')
                fig.subplots_adjust(bottom=0.15)
                ax.set_xlabel(xlabel)
                ax.set_ylabel('cumulative fraction of injections')
                ax.set_title(title)
                for d, label in zip(filtered, labels):
                    ax.add_series(d.columns.get(colname, []), label=label)
                ax.add_diagonal()
                if nsamples:
                    ax.add_confidence_band(
                        nsamples, 0.01 * opts.pp_confidence_interval)
                ax.grid()
                if len(filtered) > 1:
                    ax.legend(loc='lower right')
                fig.savefig(os.path.join(subdir, colname),
                            metadata=metadata)
                plt.close()
                progress.update()

            # Make several different kinds of histograms
            for colname, xlabel in hist_settings:
                fig = plt.figure(figsize=(6, 4.5))
                ax = fig.add_subplot(111)
                fig.subplots_adjust(bottom=0.15)
                ax.set_xscale('log')
                ax.set_xlabel(xlabel)
                ax.set_ylabel(histlabel)
                ax.set_title(title)
                # Shared log-spaced bins across all datasets, covering
                # the full range of observed values.
                values = np.concatenate(
                    [d.columns.get(colname, []) for d in filtered])
                if len(values) > 0:
                    bins = np.geomspace(np.min(values), np.max(values),
                                        1000 if opts.cumulative else 20)
                    for d, label in zip(filtered, labels):
                        ax.hist(d.columns.get(colname, []),
                                cumulative=opts.cumulative,
                                density=opts.normed, histtype='step',
                                bins=bins, label=label)
                ax.grid()
                ax.legend(loc='upper left')
                fig.savefig(os.path.join(subdir, colname + '_hist'),
                            metadata=metadata)
                plt.close()
                progress.update()

ligo/skymap/tool/ligo_skymap_plot_volume.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
#
# Copyright (C) 2015-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Plot a volumetric posterior in three-projection view."""


from . import ArgumentParser, FileType
from .matplotlib import get_figure_parser


def parser():
    """Build the command line argument parser for this tool."""
    p = ArgumentParser(parents=[get_figure_parser()])
    p.add_argument(
        '--annotate', default=False, action='store_true',
        help='annotate plot with information about the event')
    p.add_argument(
        '--max-distance', metavar='Mpc', type=float,
        help='maximum distance of plot in Mpc')
    p.add_argument(
        '--contour', metavar='PERCENT', type=float, nargs='+',
        help='plot contour enclosing this percentage of'
             ' probability mass')
    p.add_argument(
        '--radecdist', nargs=3, type=float, action='append', default=[],
        help='right ascension (deg), declination (deg), and distance to mark')
    p.add_argument(
        '--chain', metavar='CHAIN.hdf5', type=FileType('rb'),
        help='optionally plot a posterior sample chain')
    p.add_argument(
        '--projection', type=int, choices=list(range(4)), default=0,
        help='Plot one specific projection, or 0 for all projections')
    p.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    p.add_argument(
        '--align-to', metavar='SKYMAP.fits[.gz]', type=FileType('rb'),
        help='Align to the principal axes of this sky map')
    # Default to a compact square figure.
    p.set_defaults(figure_width='3.5', figure_height='3.5')
    return p


def main(args=None):
    """Run the command line tool: render the volumetric posterior."""
    opts = parser().parse_args(args)

    # Create progress bar.
    from tqdm import tqdm
    progress = tqdm()
    progress.set_description('Starting up')

    # Late imports
    import astropy_healpix as ah
    from matplotlib import pyplot as plt
    from matplotlib import gridspec
    from matplotlib import transforms
    from .. import io
    from ..plot import marker
    from ..distance import (parameters_to_marginal_moments, principal_axes,
                            volume_render, marginal_pdf)
    import healpy as hp
    import numpy as np
    import scipy.stats

    # Read input, determine input resolution.
    progress.set_description('Loading FITS file')
    (prob, mu, sigma, norm), metadata = io.read_sky_map(
        opts.input.name, distances=True)
    npix = len(prob)
    nside = ah.npix_to_nside(npix)

    progress.set_description('Preparing projection')

    # Optionally orient the view using a different sky map's principal axes.
    if opts.align_to is None or opts.input.name == opts.align_to.name:
        prob2, mu2, sigma2, norm2 = prob, mu, sigma, norm
    else:
        (prob2, mu2, sigma2, norm2), _ = io.read_sky_map(
            opts.align_to.name, distances=True)
    # Default plot extent: marginal mean distance plus 2.5 standard
    # deviations, unless the user supplied --max-distance.
    if opts.max_distance is None:
        mean, std = parameters_to_marginal_moments(prob2, mu2, sigma2)
        max_distance = mean + 2.5 * std
    else:
        max_distance = opts.max_distance
    rot = np.ascontiguousarray(principal_axes(prob2, mu2, sigma2))

    if opts.chain:
        # Transform posterior samples into the rotated (principal-axis)
        # Cartesian frame so they overlay the projections.
        chain = io.read_samples(opts.chain.name)
        chain = np.dot(rot.T, (hp.ang2vec(
            0.5 * np.pi - chain['dec'], chain['ra']) *
            np.atleast_2d(chain['dist']).T).T)

    fig = plt.figure()
    # One panel for a single projection, otherwise a 2x2 grid
    # (three faces plus the marginal distance plot).
    n = 1 if opts.projection else 2
    gs = gridspec.GridSpec(
        n, n, left=0.01, right=0.99, bottom=0.01, top=0.99,
        wspace=0.05, hspace=0.05)

    imgwidth = int(opts.dpi * opts.figure_width / n)
    s = np.linspace(-max_distance, max_distance, imgwidth)
    xx, yy = np.meshgrid(s, s)

    truth_marker = marker.reticle(
        inner=0.5 * np.sqrt(2), outer=1.5 * np.sqrt(2), angle=45)

    # Each entry: (first rotated axis, second rotated axis, grid position).
    for iface, (axis0, axis1, (sp0, sp1)) in enumerate((
            (1, 0, [0, 0]),
            (0, 2, [1, 1]),
            (1, 2, [1, 0]),)):

        if opts.projection and opts.projection != iface + 1:
            continue

        progress.set_description('Plotting projection {0}'.format(iface + 1))

        # Marginalize onto the given face
        density = volume_render(
            xx.ravel(), yy.ravel(), max_distance, axis0, axis1, rot, False,
            prob, mu, sigma, norm).reshape(xx.shape)

        # Plot heat map
        ax = fig.add_subplot(
            gs[0, 0] if opts.projection else gs[sp0, sp1], aspect=1)
        ax.imshow(
            density, origin='lower',
            extent=[-max_distance, max_distance, -max_distance, max_distance],
            cmap=opts.colormap)

        # Add contours if requested
        if opts.contour:
            # Rank pixels by density, then accumulate probability so that
            # each contour encloses the requested percentage of the mass.
            flattened_density = density.ravel()
            indices = np.argsort(flattened_density)[::-1]
            cumsum = np.empty_like(flattened_density)
            cs = np.cumsum(flattened_density[indices])
            cumsum[indices] = cs / cs[-1] * 100
            cumsum = np.reshape(cumsum, density.shape)
            u, v = np.meshgrid(s, s)
            contourset = ax.contour(
                u, v, cumsum, levels=opts.contour, linewidths=0.5)

        # Mark locations
        ax._get_lines.get_next_color()  # skip default color
        for ra, dec, dist in opts.radecdist:
            theta = 0.5 * np.pi - np.deg2rad(dec)
            phi = np.deg2rad(ra)
            xyz = np.dot(rot.T, hp.ang2vec(theta, phi) * dist)
            ax.plot(
                xyz[axis0], xyz[axis1], marker=truth_marker,
                markerfacecolor='none', markeredgewidth=1)

        # Plot chain
        if opts.chain:
            ax.plot(chain[axis0], chain[axis1], '.k', markersize=0.5)

        # Hide axes ticks
        ax.set_xticks([])
        ax.set_yticks([])

        # Set axis limits
        ax.set_xlim([-max_distance, max_distance])
        ax.set_ylim([-max_distance, max_distance])

        # Mark origin (Earth)
        ax.plot(
            [0], [0], marker=marker.earth, markersize=5,
            markerfacecolor='none', markeredgecolor='black',
            markeredgewidth=0.75)

        # NOTE(review): the third face is mirrored, presumably so its edges
        # line up with the adjacent panels — confirm against plot layout.
        if iface == 2:
            ax.invert_xaxis()

    # Add contour labels if contours requested
    if opts.contour:
        ax.clabel(contourset, fmt='%d%%', fontsize=7)

    if not opts.projection:
        # Add scale bar, 1/4 width of the plot
        ax.plot(
            [0.0625, 0.3125], [0.0625, 0.0625],
            color='black', linewidth=1, transform=ax.transAxes)
        ax.text(
            0.0625, 0.0625,
            '{0:d} Mpc'.format(int(np.round(0.5 * max_distance))),
            fontsize=8, transform=ax.transAxes, verticalalignment='bottom')

        # Create marginal distance plot.
        progress.set_description('Plotting distance')
        gs1 = gridspec.GridSpecFromSubplotSpec(5, 5, gs[0, 1])
        ax = fig.add_subplot(gs1[1:-1, 1:-1])

        # Plot marginal distance distribution, integrated over the whole sky.
        d = np.linspace(0, max_distance)
        ax.fill_between(d, marginal_pdf(d, prob, mu, sigma, norm),
                        alpha=0.5, color=ax._get_lines.get_next_color())

        # Plot conditional distance distribution at true position
        # and mark true distance.
        for ra, dec, dist in opts.radecdist:
            theta = 0.5 * np.pi - np.deg2rad(dec)
            phi = np.deg2rad(ra)
            ipix = hp.ang2pix(nside, theta, phi)
            lines, = ax.plot(
                [dist], [-0.15], marker=truth_marker,
                markerfacecolor='none', markeredgewidth=1, clip_on=False,
                transform=transforms.blended_transform_factory(
                    ax.transData, ax.transAxes))
            # Density at this pixel: N(mu, sigma) * norm * d^2.
            ax.fill_between(d, scipy.stats.norm(
                mu[ipix], sigma[ipix]).pdf(d) * norm[ipix] * np.square(d),
                alpha=0.5, color=lines.get_color())
            ax.axvline(dist, color='black', linewidth=0.5)

        # Scale axes
        ax.set_xticks([0, max_distance])
        ax.set_xticklabels(
            ['0', "{0:d}\nMpc".format(int(np.round(max_distance)))],
            fontsize=9)
        ax.set_yticks([])
        ax.set_xlim(0, max_distance)
        ax.set_ylim(0, ax.get_ylim()[1])

        if opts.annotate:
            # Assemble annotation text from whatever metadata is present.
            text = []
            try:
                objid = metadata['objid']
            except KeyError:
                pass
            else:
                text.append('event ID: {}'.format(objid))
            try:
                distmean = metadata['distmean']
                diststd = metadata['diststd']
            except KeyError:
                pass
            else:
                text.append('distance: {}±{} Mpc'.format(
                            int(np.round(distmean)), int(np.round(diststd))))
            ax.text(0, 1, '\n'.join(text), transform=ax.transAxes, fontsize=7,
                    ha='left', va='bottom', clip_on=False)

    progress.set_description('Saving')
    # Invoke the callable produced by the figure argument type: it either
    # shows the figure on screen or saves it to the requested output file.
    opts.output()

ligo/skymap/tool/ligo_skymap_stats.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Calculate summary statistics for a batch of sky maps.

Under the hood, this script is little more than a command-line interface for
the :mod:`ligo.skymap.postprocess.crossmatch` module.

The filenames of the sky maps may be provided as positional command line
arguments, and may also be provided as globs (such as ``*.fits.gz``). If
supplied with the optional ``--database`` argument, then also match sky maps
with injections from an inspinjfind-style sqlite database.

All angular separations are in degrees, all areas are in square degrees, and
all volumes are in cubic megaparsecs. The output is written as tab-separated
values with the following columns:

+-----------------------------------------------------------------------------+
| **From the search pipeline database**                                       |
+------------------------+----------------------------------------------------+
| ``coinc_event_id``     | event identifier                                   |
+------------------------+----------------------------------------------------+
| ``simulation_id``      | injection identifier                               |
+------------------------+----------------------------------------------------+
| ``far``                | false alarm rate                                   |
+------------------------+----------------------------------------------------+
| ``snr``                | signal to noise ratio                              |
+------------------------+----------------------------------------------------+
| **Injection finding**                                                       |
+------------------------+----------------------------------------------------+
| ``searched_area``      | area of 2D credible region containing the true sky |
|                        | location                                           |
+------------------------+----------------------------------------------------+
| ``searched_prob``      | probability in that 2D credible region             |
+------------------------+----------------------------------------------------+
| ``searched_prob_dist`` | marginal distance CDF at the true distance         |
+------------------------+----------------------------------------------------+
| ``searched_vol``       | volume of 3D credible region containing the true   |
|                        | position                                           |
+------------------------+----------------------------------------------------+
| ``searched_prob_vol``  | probability contained in that volume               |
+------------------------+----------------------------------------------------+
| ``offset``             | angular separation between the maximum             |
|                        | *a posteriori* position and the true sky position  |
+------------------------+----------------------------------------------------+
| **Additional metadata from the sky maps**                                   |
+------------------------+----------------------------------------------------+
| ``runtime``            | wall clock run time to generate sky map            |
+------------------------+----------------------------------------------------+
| ``distmean``           | mean *a posteriori* distance                       |
+------------------------+----------------------------------------------------+
| ``diststd``            | *a posteriori* standard deviation of distance      |
+------------------------+----------------------------------------------------+
| ``log_bci``            | natural log Bayes factor, coherent vs. incoherent  |
+------------------------+----------------------------------------------------+
| ``log_bsn``            | natural log Bayes factor, signal vs. noise         |
+------------------------+----------------------------------------------------+
| **Credible levels** (if ``--area`` or ``--contour`` options present)        |
+------------------------+----------------------------------------------------+
| ``area(P)``            | area of the *P* percent 2D credible region         |
+------------------------+----------------------------------------------------+
| ``prob(A)``            | probability contained within the 2D credible level |
|                        | of area *A*                                        |
+------------------------+----------------------------------------------------+
| ``dist(P)``            | distance for a cumulative marginal probability of  |
|                        | *P* percent                                        |
+------------------------+----------------------------------------------------+
| ``vol(P)``             | volume of the *P* percent 3D credible region       |
+------------------------+----------------------------------------------------+
| **Modes** (if ``--modes`` option is present)                                |
+------------------------+----------------------------------------------------+
| ``searched_modes``     | number of simply connected figures in the 2D       |
|                        | credible region containing the true sky location   |
+------------------------+----------------------------------------------------+
| ``modes(P)``           | number of simply connected figures in the *P*      |
|                        | percent 2D credible region                         |
+------------------------+----------------------------------------------------+

"""

from functools import partial
import sys

from astropy.coordinates import SkyCoord
from astropy import units as u
import numpy as np

from . import ArgumentParser, FileType, SQLiteType
from ..io import fits
from ..postprocess import crossmatch


def parser():
    """Build the command line argument parser for this tool."""
    p = ArgumentParser()
    add = p.add_argument
    add('-o', '--output', metavar='OUT.dat', type=FileType('w'), default='-',
        help='Name of output file')
    add('-j', '--jobs', type=int, default=1, const=None, nargs='?',
        help='Number of threads')
    add('-p', '--contour', default=[], nargs='+', type=float,
        metavar='PERCENT',
        help='Report the area of the smallest contour and the number of modes '
        'containing this much probability.')
    add('-a', '--area', default=[], nargs='+', type=float, metavar='DEG2',
        help='Report the largest probability contained within any region '
        'of this area in square degrees. Can be repeated multiple times.')
    add('--modes', action='store_true',
        help='Compute number of disjoint modes')
    add('-d', '--database', type=SQLiteType('r'), metavar='DB.sqlite',
        help='Input SQLite database from search pipeline')
    add('fitsfilenames', metavar='GLOB.fits[.gz]', nargs='+', action='glob',
        help='Input FITS filenames and/or globs')
    add('--cosmology', action='store_true',
        help='Report volume localizations as comoving volumes.')
    return p


def process(fitsfilename, db, contours, modes, areas, cosmology):
    """Compute summary statistics for one sky map.

    Returns a list of field values in the same order as the column names
    printed by :func:`main`, or None if the database contains no injection
    matched to this event.
    """
    sky_map = fits.read_sky_map(fitsfilename, moc=True)

    # Event ID and run time come from the sky map's FITS metadata.
    coinc_event_id = sky_map.meta.get('objid')
    try:
        runtime = sky_map.meta['runtime']
    except KeyError:
        runtime = float('nan')

    # Convert contour levels from percent to fractional probability.
    contour_pvalues = 0.01 * np.asarray(contours)

    if db is None:
        simulation_id = true_ra = true_dec = true_dist = far = snr = None
    else:
        # Look up the injection (sim_inspiral row) matched to this
        # coincident event, plus its false alarm rate and SNR.
        row = db.execute(
            """
            SELECT DISTINCT sim.simulation_id AS simulation_id,
            sim.longitude AS ra, sim.latitude AS dec, sim.distance AS distance,
            ci.combined_far AS far, ci.snr AS snr
            FROM coinc_event_map AS cem1 INNER JOIN coinc_event_map AS cem2
            ON (cem1.coinc_event_id = cem2.coinc_event_id)
            INNER JOIN sim_inspiral AS sim
            ON (cem1.event_id = sim.simulation_id)
            INNER JOIN coinc_inspiral AS ci
            ON (cem2.event_id = ci.coinc_event_id)
            WHERE cem1.table_name = 'sim_inspiral'
            AND cem2.table_name = 'coinc_event' AND cem2.event_id = ?
            """, (coinc_event_id,)).fetchone()
        if row is None:
            return None
        simulation_id, true_ra, true_dec, true_dist, far, snr = row

    # Build the true sky position (2D, or 3D when distance is known).
    if true_ra is None or true_dec is None:
        true_coord = None
    elif true_dist is None:
        true_coord = SkyCoord(true_ra * u.rad, true_dec * u.rad)
    else:
        true_coord = SkyCoord(true_ra * u.rad, true_dec * u.rad,
                              true_dist * u.Mpc)

    # Cross match the sky map against the true position; the unpacking
    # order must match the crossmatch result tuple.
    (
        searched_area, searched_prob, offset, searched_modes, contour_areas,
        area_probs, contour_modes, searched_prob_dist, contour_dists,
        searched_vol, searched_prob_vol, contour_vols, probdensity,
        probdensity_vol
    ) = crossmatch(
        sky_map, true_coord,
        contours=contour_pvalues, areas=areas, modes=modes, cosmology=cosmology
    )

    # Missing values are reported as NaN in the output table.
    if snr is None:
        snr = np.nan
    if far is None:
        far = np.nan
    distmean = sky_map.meta.get('distmean', np.nan)
    diststd = sky_map.meta.get('diststd', np.nan)
    log_bci = sky_map.meta.get('log_bci', np.nan)
    log_bsn = sky_map.meta.get('log_bsn', np.nan)

    # Assemble the output row in the same order as the printed column names.
    ret = [coinc_event_id]
    if db is not None:
        ret += [
            simulation_id, far, snr, searched_area,
            searched_prob, searched_prob_dist, searched_vol, searched_prob_vol,
            offset]
    ret += [runtime, distmean, diststd, log_bci, log_bsn]
    ret += contour_areas + area_probs + contour_dists + contour_vols
    if modes:
        if db is not None:
            ret += [searched_modes]
        ret += contour_modes
    return ret


def main(args=None):
    """Run the command line tool."""
    arg_parser = parser()
    opts = arg_parser.parse_args(args)

    from ..util.progress import progress_map

    # Record the command line that produced this output as a comment line.
    if args is None:
        print('#', *sys.argv, file=opts.output)
    else:
        print('#', arg_parser.prog, *args, file=opts.output)

    # Build the header row; optional column groups mirror process().
    colnames = ['coinc_event_id']
    if opts.database is not None:
        colnames.extend(['simulation_id', 'far', 'snr', 'searched_area',
                         'searched_prob', 'searched_prob_dist', 'searched_vol',
                         'searched_prob_vol', 'offset'])
    colnames.extend(['runtime', 'distmean', 'diststd', 'log_bci', 'log_bsn'])
    colnames.extend(f'area({p:g})' for p in opts.contour)
    colnames.extend(f'prob({a:g})' for a in opts.area)
    colnames.extend(f'dist({p:g})' for p in opts.contour)
    colnames.extend(f'vol({p:g})' for p in opts.contour)
    if opts.modes:
        if opts.database is not None:
            colnames.append('searched_modes')
        colnames.extend(f'modes({p:g})' for p in opts.contour)
    print(*colnames, sep='\t', file=opts.output)

    # Process each sky map (possibly in parallel) and emit one row apiece.
    func = partial(process, db=opts.database, contours=opts.contour,
                   modes=opts.modes, areas=opts.area, cosmology=opts.cosmology)
    for row in progress_map(func, opts.fitsfilenames, jobs=opts.jobs):
        if row is not None:
            print(*row, sep='\t', file=opts.output)

ligo/skymap/tool/ligo_skymap_unflatten.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Convert a HEALPix FITS file to multi-resolution UNIQ indexing from the more
common IMPLICIT indexing.
"""

from . import ArgumentParser, FileType


def parser():
    """Build the command line argument parser for this tool."""
    p = ArgumentParser()
    p.add_argument('input', metavar='INPUT.fits',
                   type=FileType('rb'), help='Input FITS file')
    p.add_argument('output', metavar='OUTPUT.fits[.gz]',
                   type=FileType('wb'), help='Output FITS file')
    return p


def main(args=None):
    """Run the command line tool."""
    opts = parser().parse_args(args)

    import warnings
    from astropy.io import fits
    from ..io import read_sky_map, write_sky_map

    hdus = fits.open(opts.input)
    ordering = hdus[1].header['ORDERING']
    # Warn (but proceed) if the input is not a flat NESTED/RING map.
    valid_orderings = {'NESTED', 'RING'}
    if ordering not in valid_orderings:
        warnings.warn(
            'Expected the FITS file {} to have ordering {}, but it is '
            '{}'.format(opts.input.name,
                        ' or '.join(valid_orderings), ordering))
    table = read_sky_map(hdus, moc=True)
    write_sky_map(opts.output.name, table)

ligo/skymap/tool/matplotlib.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
#
# Copyright (C) 2013-2023  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Functions that support the command line interface for plotting tools."""

import argparse
import os
import sys

import matplotlib

from ..plot import cmap  # noqa
from . import FileType, HelpChoicesAction, type_with_sideeffect, version_string

# Set no-op Matplotlib backend to defer importing anything that requires a GUI
# until we have determined that it is necessary based on the command line
# arguments.
if 'matplotlib.pyplot' in sys.modules:
    # pyplot was already imported by someone else; switch its backend in place.
    from matplotlib import pyplot as plt
    plt.switch_backend('Template')
else:
    # The `warn` keyword was removed from matplotlib.use() in Matplotlib 3.1,
    # so passing warn=False would raise TypeError on any modern Matplotlib.
    matplotlib.use('Template', force=True)
    from matplotlib import pyplot as plt

__all__ = ('get_figure_parser',)


class MatplotlibFigureType(FileType):
    """Argparse type for a figure output target.

    Parsing the argument has a side effect on the Matplotlib backend:
    ``-`` switches to the user's original backend and yields a callable
    that shows the figure; any other value switches to ``agg`` and yields
    a callable that saves the figure to that path.
    """

    def __init__(self):
        # Open for binary writing (figure files are binary formats).
        super().__init__('wb')

    @staticmethod
    def __show():
        from matplotlib import pyplot as plt
        return plt.show()

    @staticmethod
    def get_savefig_metadata(format):
        # Record the invoking command line in the output file's metadata;
        # the metadata key for the version string depends on the format.
        program, _ = os.path.splitext(os.path.basename(sys.argv[0]))
        cmdline = ' '.join([program] + sys.argv[1:])
        metadata = {'Title': cmdline}
        if format == 'png':
            metadata['Software'] = version_string
        elif format in {'pdf', 'ps', 'eps'}:
            metadata['Creator'] = version_string
        return metadata

    def __save(self):
        # Infer the output format from the file extension.
        from matplotlib import pyplot as plt
        _, ext = os.path.splitext(self.string)
        format = ext.lower().lstrip('.')
        metadata = self.get_savefig_metadata(format)
        return plt.savefig(self.string, metadata=metadata)

    def __call__(self, string):
        from matplotlib import pyplot as plt
        if string == '-':
            # Plot to screen: restore the backend that was configured
            # before the no-op 'Template' backend was installed.
            plt.switch_backend(matplotlib.rcParamsOrig['backend'])
            return self.__show
        else:
            # Open and immediately close the file so that an unwritable
            # path fails at argument-parsing time, then render off-screen.
            with super().__call__(string):
                pass
            plt.switch_backend('agg')
            self.string = string
            return self.__save


@type_with_sideeffect(str)
def colormap(value):
    """Argparse type that sets the default Matplotlib colormap."""
    from matplotlib import rcParams
    rcParams.update({'image.cmap': value})


@type_with_sideeffect(float)
def figwidth(value):
    """Argparse type that sets the default figure width in inches."""
    from matplotlib import rcParams
    width = float(value)
    rcParams['figure.figsize'][0] = width


@type_with_sideeffect(float)
def figheight(value):
    """Argparse type that sets the default figure height in inches."""
    from matplotlib import rcParams
    height = float(value)
    rcParams['figure.figsize'][1] = height


@type_with_sideeffect(int)
def dpi(value):
    """Argparse type that sets the on-screen and saved figure resolution."""
    from matplotlib import rcParams
    resolution = float(value)
    rcParams['savefig.dpi'] = resolution
    rcParams['figure.dpi'] = resolution


@type_with_sideeffect(int)
def transparent(value):
    """Argparse type that toggles transparent figure backgrounds."""
    from matplotlib import rcParams
    flag = bool(value)
    rcParams['savefig.transparent'] = flag


def get_figure_parser():
    """Return an argument parser providing common figure output options."""
    p = argparse.ArgumentParser(add_help=False)
    group = p.add_argument_group(
        'figure options', 'Options that affect figure output format')
    add = group.add_argument
    add('-o', '--output', metavar='FILE.{pdf,png}',
        default='-', type=MatplotlibFigureType(),
        help='output file, or - to plot to screen')
    add('--colormap', default='cylon', choices=plt.colormaps(), type=colormap,
        metavar='CMAP', help='matplotlib colormap')
    add('--help-colormap', action=HelpChoicesAction, choices=plt.colormaps())
    add('--figure-width', metavar='INCHES', type=figwidth, default='8',
        help='width of figure in inches')
    add('--figure-height', metavar='INCHES', type=figheight, default='6',
        help='height of figure in inches')
    add('--dpi', metavar='PIXELS', type=dpi, default=300,
        help='resolution of figure in dots per inch')
    add('--transparent', const='1', default='0', nargs='?', type=transparent,
        help='Save image with transparent background')
    return p

ligo/skymap/util/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

# Public API of this package, accumulated below from each submodule.
__all__ = ()

# Import all symbols from all submodules of this module.
# exec is used so the submodule's name can be spliced into the import
# statements; each submodule's __all__ (if any) is folded into ours.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    del module

# Clean up
del os, pkgutil

ligo/skymap/util/file.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
#
# Copyright (C) 2018-2019  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""File tools."""
import errno
import os
import shutil
import tempfile


def rename(src, dst):
    """Like `os.rename`, but works across different devices because it
    catches and handles ``EXDEV`` (``Invalid cross-device link``) errors.

    NOTE(review): in the cross-device fallback the original ``src`` file is
    left in place rather than removed, unlike a true ``os.rename`` — confirm
    whether callers depend on this.
    """
    try:
        os.rename(src, dst)
    except OSError as e:
        if e.errno != errno.EXDEV:
            raise
        # Cross-device move: copy into a temporary file in the destination
        # directory, then atomically rename it into place.
        dst_dir, dst_name = os.path.split(dst)
        fd, tmp_path = tempfile.mkstemp(dir=dst_dir, suffix=dst_name)
        try:
            os.close(fd)
            shutil.copy2(src, tmp_path)
            os.rename(tmp_path, dst)
        except:  # noqa: E722
            # Do not leave the temporary file behind on failure.
            os.remove(tmp_path)
            raise


def rm_f(filename):
    """Remove a file, or be silent if the file does not exist, like ``rm -f``.

    Examples
    --------
    >>> with tempfile.TemporaryDirectory() as d:
    ...     rm_f('test')
    ...     with open('test', 'w') as f:
    ...         print('Hello world', file=f)
    ...     rm_f('test')

    """
    # os.unlink is an alias of os.remove; a missing file is not an error.
    try:
        os.unlink(filename)
    except FileNotFoundError:
        pass

ligo/skymap/util/ilwd.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
#
# Copyright (C) 2020-2022  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Tools for adapting LIGO-LW row ID formats."""
import re

from ligo.lw.ligolw import Param
from ligo.lw.lsctables import TableByName
from ligo.lw.table import Column, TableStream
from ligo.lw.types import FormatFunc, FromPyType

__all__ = ('use_in',)

# LIGO-LW column/param types that hold old-style string row IDs
# (e.g. "process:process_id:0").
IDTypes = {'ilwd:char', 'ilwd:char_u'}

# New-style row IDs are plain integers: look up the corresponding LIGO-LW
# type name and the function that formats such a value for XML output.
ROWID_PYTYPE = int
ROWID_TYPE = FromPyType[ROWID_PYTYPE]
ROWID_FORMATFUNC = FormatFunc[ROWID_TYPE]


# Matches an old-style ilwd identifier of the form "table:column:digits",
# capturing the integer row ID; surrounding whitespace is tolerated.
_ilwd_regex = re.compile(r'\s*\w+:\w+:(\d+)\s*')


def ilwd_to_int(ilwd):
    """Convert an old-style row ID string (e.g. "process:process_id:0") to int.

    Parameters
    ----------
    ilwd : str
        An ilwd:char-style identifier.

    Returns
    -------
    int
        The trailing integer row ID.

    Raises
    ------
    ValueError
        If the input does not match the ``table:column:digits`` format.
    """
    match = _ilwd_regex.fullmatch(ilwd)
    if not match:
        # Fixed typo in the error message ("formatt" -> "formatted").
        raise ValueError(f'"{ilwd}" is not formatted like an ilwd')
    return int(match[1])


def use_in(ContentHandler):
    """Convert from old-style to new-style row IDs on the fly.

    This is loosely adapted from :func:`ligo.lw.utils.ilwd.strip_ilwdchar`.

    Notes
    -----
    When building a ContentHandler, this must be the _outermost_ decorator,
    outside of :func:`ligo.lw.lsctables.use_in`, :func:`ligo.lw.param.use_in`,
    or :func:`ligo.lw.table.use_in`.

    Examples
    --------
    >>> from importlib.resources import as_file, files
    >>> from ligo.lw import array, ligolw, lsctables, param, table, utils
    >>> from ligo.skymap.util import ilwd
    >>> @ilwd.use_in
    ... @lsctables.use_in
    ... @param.use_in
    ... @table.use_in
    ... class ContentHandler(ligolw.LIGOLWContentHandler):
    ...     pass
    >>> with as_file(files('ligo.skymap.io.tests.data').joinpath(
    ...         'G197392_coinc.xml.gz')) as f:
    ...     xmldoc = utils.load_filename(f, contenthandler=ContentHandler)
    >>> table = lsctables.SnglInspiralTable.get_table(xmldoc)
    >>> table[0].process_id
    0

    """

    def endElementNS(self, uri_localname, qname,
                     __orig_endElementNS=ContentHandler.endElementNS):
        """Convert values of <Param> elements from ilwdchar to int."""
        # The element's character data is complete once the element closes;
        # rewrite both the declared type and the formatted value in place.
        if isinstance(self.current, Param) and self.current.Type in IDTypes:
            new_value = ilwd_to_int(self.current.pcdata)
            self.current.Type = ROWID_TYPE
            self.current.pcdata = ROWID_FORMATFUNC(new_value)
        __orig_endElementNS(self, uri_localname, qname)

    # Maps (id(parent table), column name) -> converter for each ilwdchar
    # column discovered by startColumn; entries are popped again by
    # startStream when it configures the stream tokenizer.
    remapped = {}

    def startColumn(self, parent, attrs,
                    __orig_startColumn=ContentHandler.startColumn):
        """Convert types in <Column> elements from ilwdchar to int.

        Notes
        -----
        This method is adapted from
        :func:`ligo.lw.utils.ilwd.strip_ilwdchar`.

        """
        result = __orig_startColumn(self, parent, attrs)

        # If this is an ilwdchar column, then create a function to convert its
        # rows' values for use in the startStream method below.
        if result.Type in IDTypes:
            remapped[(id(parent), result.Name)] = ilwd_to_int
            result.Type = ROWID_TYPE

        # If the column name is not valid for this table, try mapping it to a
        # valid column name whose stripped base name matches.
        if parent.Name in TableByName:
            validcolumns = TableByName[parent.Name].validcolumns
            if result.Name not in validcolumns:
                stripped_column_to_valid_column = {
                    Column.ColumnName(name): name for name in validcolumns}
                if result.Name in stripped_column_to_valid_column:
                    result.setAttribute(
                        'Name', stripped_column_to_valid_column[result.Name])

        return result

    def startStream(self, parent, attrs,
                    __orig_startStream=ContentHandler.startStream):
        """Convert values in table <Stream> elements from ilwdchar to int.

        Notes
        -----
        This method is adapted from
        :meth:`ligo.lw.table.TableStream.config`.

        """
        result = __orig_startStream(self, parent, attrs)
        if isinstance(result, TableStream):
            loadcolumns = set(parent.columnnames)
            if parent.loadcolumns is not None:
                # FIXME:  convert loadcolumns attributes to sets to
                # avoid the conversion.
                loadcolumns &= set(parent.loadcolumns)
            # For each column, prefer the ilwd converter recorded by
            # startColumn, falling back to the column's ordinary Python type;
            # None marks columns that should be skipped entirely.
            result._tokenizer.set_types([
                (remapped.pop((id(parent), colname), pytype)
                 if colname in loadcolumns else None)
                for pytype, colname
                in zip(parent.columnpytypes, parent.columnnames)])
        return result

    ContentHandler.endElementNS = endElementNS
    ContentHandler.startColumn = startColumn
    ContentHandler.startStream = startStream

    return ContentHandler

ligo/skymap/util/math.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
#
# Copyright (C) 2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Math utilities."""

__all__ = ('derivative',)


def derivative(func, x0, dx=1.0):
    """Estimate the derivative of *func* at *x0* by central differences.

    Parameters
    ----------
    func : callable
        Function of a single scalar argument.
    x0 : float
        Point at which to evaluate the derivative.
    dx : float, optional
        Half-width of the differencing stencil.

    Returns
    -------
    float
        The central-difference estimate ``(f(x0+dx) - f(x0-dx)) / (2 dx)``.
    """
    upper = func(x0 + dx)
    lower = func(x0 - dx)
    return 0.5 * (upper - lower) / dx

ligo/skymap/util/numpy.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
#
# Copyright (C) 2018-2024  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import functools
import numpy as np
from numpy.core.umath import _add_newdoc_ufunc

__all__ = ('add_newdoc_ufunc', 'require_contiguous_aligned')


def add_newdoc_ufunc(func, doc):  # pragma: no cover
    """Set the docstring for a Numpy ufunc.

    The function :func:`numpy.core.umath._add_newdoc_ufunc` can only change a
    ufunc's docstring if it is `NULL`. This workaround avoids an exception when
    the user tries to `reload()` this module.

    Parameters
    ----------
    func : numpy.ufunc
        The ufunc whose docstring should be set.
    doc : str
        The new docstring.

    Notes
    -----
    :func:`numpy.core.umath._add_newdoc_ufunc` is not part of Numpy's public
    API, but according to upstream developers it is unlikely to go away any
    time soon.

    See https://github.com/numpy/numpy/issues/26233.
    """
    try:
        _add_newdoc_ufunc(func, doc)
    except ValueError as e:
        msg = 'Cannot change docstring of ufunc with non-NULL docstring'
        # Only the specific "docstring already set" error is expected and
        # benign (it occurs on reload). The original code silently swallowed
        # *every* ValueError; re-raise anything else.
        if e.args[0] != msg:
            raise


def require_contiguous_aligned(func):
    """Wrap a Numpy ufunc to guarantee that all of its inputs are
    C-contiguous arrays.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        nin = func.nin
        # Only the ufunc's inputs (the first `nin` positional arguments)
        # need conversion; scalars pass through untouched.
        converted = []
        for position, arg in enumerate(args):
            if position < nin and not np.isscalar(arg):
                arg = np.require(arg, requirements={'CONTIGUOUS', 'ALIGNED'})
            converted.append(arg)
        return func(*converted, **kwargs)
    return wrapper

ligo/skymap/util/progress.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
#
# Copyright (C) 2019-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Tools for progress bars"""

try:
    from billiard import Pool
except ImportError:
    from multiprocessing import Pool
from heapq import heappop, heappush
from operator import length_hint

from tqdm.auto import tqdm

from .. import omp

__all__ = ('progress_map',)


class WrappedFunc:
    """Pickleable callable that pairs each result with its input's index.

    Wraps ``func`` so that calling the wrapper with ``(i, args)`` returns
    ``(i, func(*args))``, allowing out-of-order pool results to be matched
    back to their inputs.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, indexed_args):
        index, arguments = indexed_args
        return index, self.func(*arguments)


def _get_total_estimate(*iterables):
    """Estimate total loop iterations for mapping over multiple iterables."""
    # The shortest iterable limits the total; ignore iterables for which no
    # length hint is available (length_hint returns the -1 sentinel).
    total = None
    for iterable in iterables:
        hint = length_hint(iterable, -1)
        if hint != -1 and (total is None or hint < total):
            total = hint
    return total


def _results_in_order(completed):
    """Put results back into order and yield them as quickly as they arrive."""
    # Min-heap of (index, result) pairs that arrived ahead of their turn.
    pending = []
    expected = 0
    for item in completed:
        heappush(pending, item)
        # Drain every result that is now contiguous with what we have yielded.
        while pending and pending[0][0] == expected:
            _, value = heappop(pending)
            yield value
            expected += 1
    assert not pending, 'The heap must be empty'


def _init_process():
    """Disable OpenMP when using multiprocessing."""
    # Worker-pool initializer: each worker should be single-threaded,
    # otherwise `jobs` processes times `omp.num_threads` threads would
    # oversubscribe the CPU.
    omp.num_threads = 1


def progress_map(func, *iterables, jobs=1, **kwargs):
    """Map a function across iterables of arguments.

    This is comparable to :meth:`astropy.utils.console.ProgressBar.map`, except
    that it is implemented using :mod:`tqdm` and so provides more detailed and
    accurate progress information.

    Parameters
    ----------
    func : callable
        Function applied to each tuple of corresponding elements.
    *iterables
        Positional argument sequences, consumed in lockstep as with `map`.
    jobs : int, optional
        Number of worker processes; ``1`` (the default) runs serially.
    **kwargs
        Additional keyword arguments for :func:`tqdm.auto.tqdm`.
    """
    total = _get_total_estimate(*iterables)
    if jobs == 1:
        # Serial path: no pool overhead, results are naturally ordered.
        yield from tqdm(map(func, *iterables), total=total, **kwargs)
        return
    # Parallel path: tag each argument tuple with its index so that the
    # pool's unordered results can be reassembled into input order.
    with Pool(jobs, _init_process) as pool:
        indexed = enumerate(zip(*iterables))
        unordered = pool.imap_unordered(WrappedFunc(func), indexed)
        yield from _results_in_order(tqdm(unordered, total=total, **kwargs))

ligo/skymap/util/sqlite.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Tools for reading and writing SQLite databases."""
import copyreg
import sqlite3
# Keep a reference to the builtin open() before this module's own `open`
# function (defined below) shadows it.
_open = open


def _open_a(string):
    """Open an SQLite database for modification in place ('a' mode)."""
    connection = sqlite3.connect(string, check_same_thread=False)
    return connection


def _open_r(string):
    """Open an SQLite database read-only ('r' mode) via an SQLite URI."""
    uri = 'file:{}?mode=ro'.format(string)
    return sqlite3.connect(uri, check_same_thread=False, uri=True)


def _open_w(string):
    """Open an SQLite database for writing ('w' mode), clobbering any
    existing file first."""
    # Truncate the file with the builtin open() (saved as `_open` at module
    # level) so the new connection starts from an empty database.
    with _open(string, 'wb'):
        pass
    return sqlite3.connect(string, check_same_thread=False)


# Dispatch table from `open`-style mode characters to opener functions.
_openers = {'a': _open_a, 'r': _open_r, 'w': _open_w}


def open(string, mode):
    """Open an SQLite database with an `open`-style mode flag.

    Parameters
    ----------
    string : str
        Path of the SQLite database file
    mode : {'r', 'w', 'a'}
        Access mode: read only, clobber and overwrite, or modify in place.

    Returns
    -------
    connection : `sqlite3.Connection`

    Raises
    ------
    ValueError
        If the filename is invalid (e.g. ``/dev/stdin``), or if the requested
        mode is invalid
    OSError
        If the database could not be opened in the specified mode

    Examples
    --------
    >>> import tempfile
    >>> import os
    >>> with tempfile.TemporaryDirectory() as d:
    ...     open(os.path.join(d, 'test.sqlite'), 'w')
    ...
    <sqlite3.Connection object at 0x...>

    >>> with tempfile.TemporaryDirectory() as d:
    ...     open(os.path.join(d, 'test.sqlite'), 'r')
    ...
    Traceback (most recent call last):
      ...
    OSError: Failed to open database ...

    >>> open('/dev/stdin', 'r')
    Traceback (most recent call last):
      ...
    ValueError: Cannot open stdin/stdout as an SQLite database

    >>> open('test.sqlite', 'x')
    Traceback (most recent call last):
      ...
    ValueError: Invalid mode "x". Must be one of "arw".

    """
    if string in {'-', '/dev/stdin', '/dev/stdout'}:
        raise ValueError('Cannot open stdin/stdout as an SQLite database')
    try:
        opener = _openers[mode]
    except KeyError:
        # Suppress the internal KeyError: it is an implementation detail and
        # only adds noise to the traceback shown to the caller.
        raise ValueError('Invalid mode "{}". Must be one of "{}".'.format(
            mode, ''.join(sorted(_openers.keys())))) from None
    try:
        return opener(string)
    except (OSError, sqlite3.Error) as e:
        # Chain the underlying error so the root cause stays visible.
        raise OSError(
            'Failed to open database {}: {}'.format(string, e)) from e


def get_filename(connection):
    r"""Get the name of the file associated with an SQLite connection.

    Parameters
    ----------
    connection : `sqlite3.Connection`
        The database connection

    Returns
    -------
    str
        The name of the file that contains the SQLite database

    Raises
    ------
    RuntimeError
        If more than one database is attached to the connection

    Examples
    --------
    >>> import tempfile
    >>> import os
    >>> with tempfile.TemporaryDirectory() as d:
    ...     with sqlite3.connect(os.path.join(d, 'test.sqlite')) as db:
    ...         print(get_filename(db))
    ...
    /.../test.sqlite

    >>> with tempfile.TemporaryDirectory() as d:
    ...     with sqlite3.connect(os.path.join(d, 'test1.sqlite')) as db1, \
    ...          sqlite3.connect(os.path.join(d, 'test2.sqlite')) as db2:
    ...         filename = get_filename(db1)
    ...         db2.execute('ATTACH DATABASE "{}" AS db2'.format(filename))
    ...         print(get_filename(db2))
    ...
    Traceback (most recent call last):
      ...
    RuntimeError: Expected exactly one attached database

    """
    result = connection.execute('pragma database_list').fetchall()
    try:
        # Unpack exactly one (seq, name, file) row; more or fewer rows raise.
        (_, _, filename), = result
    except ValueError:
        # The unpacking ValueError is an implementation detail; suppress it
        # so callers see only the RuntimeError.
        raise RuntimeError('Expected exactly one attached database') from None
    return filename


def pickle_sqlite3_connection(obj):
    # Reduce a connection to (reconstructor, args): unpickling reopens the
    # same file in 'a' mode. Open transactions and cursors are not preserved.
    return _open_a, (get_filename(obj),)


# Register the reducer so that sqlite3.Connection objects can be pickled,
# e.g. when passed to multiprocessing worker processes.
copyreg.pickle(sqlite3.Connection, pickle_sqlite3_connection)

ligo/skymap/util/stopwatch.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
#
# Copyright (C) 2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Performance measurement utilities."""
from resource import getrusage, RUSAGE_SELF
from time import perf_counter

import numpy as np


class StopwatchTimes:
    """Bundle of real (wall-clock), user CPU, and system CPU times."""

    def __init__(self, real=0, user=0, sys=0):
        self.real = real
        self.user = user
        self.sys = sys

    def __iadd__(self, other):
        """Accumulate another set of times into this one, in place."""
        self.real, self.user, self.sys = (self.real + other.real,
                                          self.user + other.user,
                                          self.sys + other.sys)
        return self

    def __isub__(self, other):
        """Subtract another set of times from this one, in place."""
        self.real, self.user, self.sys = (self.real - other.real,
                                          self.user - other.user,
                                          self.sys - other.sys)
        return self

    def __add__(self, other):
        """Return a new StopwatchTimes holding the component-wise sum."""
        return StopwatchTimes(self.real + other.real,
                              self.user + other.user,
                              self.sys + other.sys)

    def __sub__(self, other):
        """Return a new StopwatchTimes holding the component-wise difference."""
        return StopwatchTimes(self.real - other.real,
                              self.user - other.user,
                              self.sys - other.sys)

    def __repr__(self):
        name = self.__class__.__name__
        return '{}(real={!r}, user={!r}, sys={!r})'.format(
            name, self.real, self.user, self.sys)

    def __str__(self):
        # Fixed-precision (3 decimal places) rendering of each component.
        formatted = [np.format_float_positional(value, 3, unique=False)
                     for value in (self.real, self.user, self.sys)]
        return 'real={}s, user={}s, sys={}s'.format(*formatted)

    @classmethod
    def now(cls):
        """Capture the current wall-clock and cumulative CPU times."""
        usage = getrusage(RUSAGE_SELF)
        return cls(perf_counter(), usage.ru_utime, usage.ru_stime)

    def reset(self):
        """Zero out all three accumulated times."""
        self.real = self.user = self.sys = 0


class Stopwatch(StopwatchTimes):
    """A code profiling utility that mimics the interface of a stopwatch."""

    def __init__(self):
        super().__init__()
        # Snapshot taken by the most recent start(); None while stopped.
        self._base = None

    def start(self):
        """Begin timing from the current instant."""
        self._base = StopwatchTimes.now()

    def stop(self):
        """Accumulate the elapsed interval and return the running totals."""
        elapsed = StopwatchTimes.now() - self._base
        self += elapsed
        self._base = None
        return self

    def lap(self):
        """Accumulate and return the interval elapsed since start().

        NOTE(review): ``_base`` is not reset here, so consecutive ``lap()``
        calls (or a ``lap()`` followed by ``stop()``) each measure from the
        same ``start()`` instant and the totals accumulate overlapping
        intervals — confirm this is intended.
        """
        elapsed = StopwatchTimes.now() - self._base
        self += elapsed
        return elapsed

src/bayestar_distance.c

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
/*
 * Copyright (C) 2015-2017  Leo Singer
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */


#include "bayestar_distance.h"

#include <gsl/gsl_cblas.h>
#include <gsl/gsl_errno.h>
#include <gsl/gsl_math.h>
#include <gsl/gsl_roots.h>
#include <gsl/gsl_sf_erf.h>
#include <gsl/gsl_sf_exp.h>
#include <gsl/gsl_cdf.h>
#include <gsl/gsl_statistics.h>

#include <chealpix.h>


/* Conditional distance PDF:
 *   p(r | mu, sigma, norm) = norm * r^2
 *       * exp(-(r - mu)^2 / (2 sigma^2)) / (sqrt(2 pi) sigma).
 * A non-finite mu marks a pixel with no distance information: return 0. */
double bayestar_distance_conditional_pdf(
    double r, double mu, double sigma, double norm)
{
    if (!isfinite(mu))
        return 0;

    /* gsl_sf_exp_mult(x, y) evaluates y * exp(x) while guarding against
     * premature overflow/underflow of the exponential. */
    const double x = -0.5 * gsl_pow_2((r - mu) / sigma);
    const double y = norm * gsl_pow_2(r) / (sqrt(2 * M_PI) * sigma);
    return gsl_sf_exp_mult(x, y);
}


/* Integral of the standard normal PDF from x1 to x2.  When both limits lie
 * in the same tail, subtracting two nearly equal CDF values would cancel
 * catastrophically, so work with log-erfc (the complementary error function
 * in log space) instead. */
static double ugaussian_integral(double x1, double x2)
{
    if (GSL_SIGN(x1) != GSL_SIGN(x2))
    {
        /* Limits straddle zero: the direct CDF difference is
         * well-conditioned. */
        return gsl_cdf_ugaussian_P(x2) - gsl_cdf_ugaussian_P(x1);
    } else if (x1 > 0) {
        /* Both limits in the upper tail: difference of survival functions. */
        const double logerfc1 = gsl_sf_log_erfc(x1 * M_SQRT1_2);
        const double logerfc2 = gsl_sf_log_erfc(x2 * M_SQRT1_2);
        return 0.5 * (exp(logerfc1) - exp(logerfc2));
    } else {
        /* Both limits in the lower tail: mirror image of the case above. */
        const double logerfc1 = gsl_sf_log_erfc(-x1 * M_SQRT1_2);
        const double logerfc2 = gsl_sf_log_erfc(-x2 * M_SQRT1_2);
        return 0.5 * (exp(logerfc2) - exp(logerfc1));
    }
}


/* Conditional distance CDF: closed-form integral of the conditional PDF
 * from 0 to r.  A non-finite mu marks a pixel with no distance
 * information: return 0. */
double bayestar_distance_conditional_cdf(
    double r, double mu, double sigma, double norm)
{
    if (!isfinite(mu))
        return 0;

    const double mu2 = gsl_pow_2(mu);
    const double sigma2 = gsl_pow_2(sigma);
    /* Integration limits expressed in standard-normal units. */
    const double arg1 = -mu / sigma;
    const double arg2 = (r - mu) / sigma;

    /* Antiderivative of r^2 exp(-(r - mu)^2 / (2 sigma^2)): a Gaussian
     * integral term plus boundary terms, the latter evaluated with
     * gsl_sf_exp_mult to avoid underflow of the exponentials. */
    return (
        (mu2 + sigma2) * ugaussian_integral(arg1, arg2)
        + sigma / sqrt(2 * M_PI) * (gsl_sf_exp_mult(-0.5 * gsl_pow_2(arg1), mu)
        - gsl_sf_exp_mult(-0.5 * gsl_pow_2(arg2), r + mu))
    ) * norm;
}


/* Parameters for the quantile (PPF) root solve; sigma has already been
 * normalized to 1 by the caller. */
typedef struct {
    double p, mu, norm;
} conditional_ppf_params;


/* Residual and derivative for solving CDF(r) = p with sigma = 1.  The
 * residual is formed in log space of whichever tail is nearer the target
 * quantile (upper tail for p > 0.5, lower tail otherwise) for better
 * conditioning near the extremes. */
static void conditional_ppf_fdf(double r, void *params, double *f, double *df)
{
    const conditional_ppf_params *p = (conditional_ppf_params *)params;
    const double _f = bayestar_distance_conditional_cdf(r, p->mu, 1, p->norm);
    const double _df = bayestar_distance_conditional_pdf(r, p->mu, 1, p->norm);
    if (p->p > 0.5)
    {
        /* f = log(SF(r)) - log(1 - p); chain rule gives df = -pdf / SF. */
        *f = log(1 - _f) - log(1 - p->p);
        *df = -_df / (1 - _f);
    } else {
        /* f = log(CDF(r)) - log(p); chain rule gives df = pdf / CDF. */
        *f = log(_f) - log(p->p);
        *df = _df / _f;
    }
}


/* GSL f-callback adapter: evaluate the residual only, discarding the
 * derivative. */
static double conditional_ppf_f(double r, void *params)
{
    double f, df;
    conditional_ppf_fdf(r, params, &f, &df);
    return f;
}


/* GSL df-callback adapter: evaluate the derivative only, discarding the
 * residual. */
static double conditional_ppf_df(double r, void *params)
{
    double f, df;
    conditional_ppf_fdf(r, params, &f, &df);
    return df;
}


/* Starting point for the PPF Newton iteration (sigma normalized to 1).
 * Must be positive for the iteration to make sense, hence the fallbacks. */
static double conditional_ppf_initial_guess(double p, double mu)
{
    /* Initial guess: ignore r^2 term;
     * distribution becomes truncated Gaussian */
    const double z = gsl_cdf_ugaussian_Pinv(p + (1 - p) * gsl_cdf_ugaussian_P(-mu)) + mu;

    if (z > 0)
        return z;
    else if (mu > 0)
        return mu;  /* Fallback 1: mean */
    else
        return 0.5;  /* Fallback 2: constant value */
}


/* Quantile function (inverse CDF) of the conditional distance distribution:
 * solve CDF(r) = p for r by a Steffenson (Newton-type) iteration. */
double bayestar_distance_conditional_ppf(
    double p, double mu, double sigma, double norm)
{
    /* Handle edge quantiles and invalid inputs up front. */
    if (p <= 0)
        return 0;
    else if (p >= 1)
        return GSL_POSINF;
    else if (!(isfinite(p) && isfinite(mu)
            && isfinite(sigma) && isfinite(norm)))
        return GSL_NAN;

    /* Convert to standard distribution with sigma = 1. */
    mu /= sigma;
    norm *= gsl_pow_2(sigma);

    /* Set up variables for tracking progress toward the solution. */
    static const int max_iter = 50;
    conditional_ppf_params params = {p, mu, norm};
    int iter = 0;
    double z = conditional_ppf_initial_guess(p, mu);
    int status;

    /* Set up solver (on stack).
     * NOTE(review): this initializes a gsl_root_fdfsolver by hand, with the
     * state buffer as a VLA sized from the algorithm descriptor, to avoid a
     * heap allocation via gsl_root_fdfsolver_alloc; it relies on the struct
     * layout {type, function, root, state} — confirm against the GSL
     * version in use. */
    const gsl_root_fdfsolver_type *algo = gsl_root_fdfsolver_steffenson;
    char state[algo->size];
    gsl_root_fdfsolver solver = {algo, NULL, 0, state};
    gsl_function_fdf fun = {
        conditional_ppf_f, conditional_ppf_df, conditional_ppf_fdf, &params};
    gsl_root_fdfsolver_set(&solver, &fun, z);

    /* Iterate until successive roots agree to relative tolerance
     * GSL_SQRT_DBL_EPSILON or max_iter is exhausted. */
    do
    {
        const double zold = z;
        status = gsl_root_fdfsolver_iterate(&solver);
        z = gsl_root_fdfsolver_root(&solver);
        status = gsl_root_test_delta (z, zold, 0, GSL_SQRT_DBL_EPSILON);
        iter++;
    } while (status == GSL_CONTINUE && iter < max_iter);
    /* FIXME: do something with status? */

    /* Rescale to original value of sigma. */
    z *= sigma;

    return z;
}


/* Auxiliary quantities for the moment-matching equations.  H is the
 * Gaussian hazard function evaluated at -z and Hp its derivative with
 * respect to z; x2, x3, x4 appear in the callers as (relative) 2nd, 3rd,
 * and 4th moments of the quadratically weighted truncated normal ansatz,
 * and dx2, dx3, dx4 are their derivatives with respect to z.
 * NOTE(review): the moment interpretation is inferred from the callers
 * below — confirm against the BAYESTAR distance ansatz derivation. */
static void integrals(
    double z,
    double *x2, double *x3, double *x4,
    double *dx2, double *dx3, double *dx4)
{
    const double H = gsl_sf_hazard(- z);
    const double Hp = - H * (z + H);
    const double z2 = gsl_pow_2(z);
    *x2 = z2 + 1 + z * H;
    *x3 = z * (z2 + 3) + (z2 + 2) * H;
    *x4 = z2 * (z2 + 6) + 3 + z * (z2 + 5) * H;
    *dx2 = 2 * z + H + z * Hp;
    *dx3 = 3 * (z2 + 1) + 2 * z * H + (z2 + 2) * Hp;
    *dx4 = 4 * z * (z2 + 3) + (3 * z2 + 5) * H + z * (z2 + 5) * Hp;
}


/* Root-finding residual for z = mu/sigma: zero when z reproduces the
 * requested mean/std ratio via the identity
 *   (1 / (mean/std)^2 + 1) * x3^2 = x4 * x2. */
static void moments_to_parameters_fdf(
    double z, void *params, double *fval, double *dfval)
{
    const double mean_std = *(double *)params;
    const double target = 1 / gsl_pow_2(mean_std) + 1;
    double x2, x3, x4, dx2, dx3, dx4;
    integrals(z, &x2, &x3, &x4, &dx2, &dx3, &dx4);
    *fval = target * gsl_pow_2(x3) - x4 * x2;
    /* Product rule applied to the residual above. */
    *dfval = target * 2 * x3 * dx3 - x4 * dx2 - dx4 * x2;
}


/* GSL f-callback adapter: residual value only. */
static double moments_to_parameters_f(double z, void *params)
{
    double fval, dfval;
    moments_to_parameters_fdf(z, params, &fval, &dfval);
    return fval;
}


/* GSL df-callback adapter: residual derivative only. */
static double moments_to_parameters_df(double z, void *params)
{
    double fval, dfval;
    moments_to_parameters_fdf(z, params, &fval, &dfval);
    return dfval;
}


/* Solve moments_to_parameters_fdf(z) = 0 for z, starting from z = mean_std.
 * Returns the GSL status from the final convergence test. */
static int solve_z(double mean_std, double *result)
{
    /* Set up variables for tracking progress toward the solution. */
    static const int max_iter = 50;
    int iter = 0;
    double z = mean_std;
    int status;

    /* Set up solver (on stack).
     * NOTE(review): hand-initialized gsl_root_fdfsolver with a VLA state
     * buffer, same layout assumption as in
     * bayestar_distance_conditional_ppf above. */
    const gsl_root_fdfsolver_type *algo = gsl_root_fdfsolver_steffenson;
    char state[algo->size];
    gsl_root_fdfsolver solver = {algo, NULL, 0, state};
    gsl_function_fdf fun = {
        moments_to_parameters_f,
        moments_to_parameters_df,
        moments_to_parameters_fdf,
        &mean_std};
    gsl_root_fdfsolver_set(&solver, &fun, z);

    /* Iterate until successive roots agree to relative tolerance
     * GSL_SQRT_DBL_EPSILON or max_iter is exhausted. */
    do
    {
        const double zold = z;
        status = gsl_root_fdfsolver_iterate(&solver);
        z = gsl_root_fdfsolver_root(&solver);
        status = gsl_root_test_delta (z, zold, 0, GSL_SQRT_DBL_EPSILON);
        iter++;
    } while (status == GSL_CONTINUE && iter < max_iter);

    *result = z;
    return status;
}


/* Fit the (mu, sigma, norm) ansatz parameters from the first two moments
 * (mean, std) of a pixel's distance posterior.  Returns a GSL status code.
 * Pixels whose mean/std ratio is non-finite or below the representable
 * minimum are marked uninformative: mu = inf, sigma = 1, norm = 0. */
int bayestar_distance_moments_to_parameters(
    double mean, double std, double *mu, double *sigma, double *norm)
{
    /* Set up function to solve. */
    double mean_std = mean / std;
    /* Minimum value of (mean/std) for a quadratically weighted
     * normal distribution. The limit of (mean/std) as (mu/sigma) goes to -inf
     * is sqrt(3). We limit (mean/std) to a little bit more than sqrt(3),
     * because as (mu/sigma) becomes more and more negative the normalization
     * has to get very large.
     */
    static const double min_mean_std = M_SQRT3 + 1e-2;
    int status;

    if (gsl_finite(mean_std) && mean_std >= min_mean_std)
    {
        double z, x2, x3, x4, dx2, dx3, dx4;
        /* Solve for z = mu/sigma, then recover sigma, mu, and the
         * normalization from the moment expressions. */
        status = solve_z(mean_std, &z);
        integrals(z, &x2, &x3, &x4, &dx2, &dx3, &dx4);
        *sigma = mean * x2 / x3;
        *mu = *sigma * z;
        *norm = 1 / (gsl_pow_2(*sigma) * x2 * gsl_sf_erf_Q(-z));
    } else {
        status = GSL_SUCCESS;
        *mu = INFINITY;
        *sigma = 1;
        *norm = 0;
    }

    return status;
}


/* Inverse of bayestar_distance_moments_to_parameters: recover (mean, std,
 * norm) from the ansatz parameters (mu, sigma).  A non-finite mu/sigma
 * ratio marks an uninformative pixel: mean = inf, std = 1, norm = 0. */
void bayestar_distance_parameters_to_moments(
    double mu, double sigma, double *mean, double *std, double *norm)
{
    if (gsl_finite(mu / sigma))
    {
        const double z = mu / sigma;
        double x2, x3, x4, dx2, dx3, dx4;

        integrals(z, &x2, &x3, &x4, &dx2, &dx3, &dx4);

        /* Mean and standard deviation from the moment ratios; the
         * derivatives dx2..dx4 are unused here. */
        *mean = sigma * x3 / x2;
        *std = *mean * sqrt(x4 * x2 / gsl_pow_2(x3) - 1);
        *norm = 1 / (gsl_pow_2(sigma) * x2 * gsl_sf_erf_Q(-z));
    } else {
        *mean = INFINITY;
        *std = 1;
        *norm = 0;
    }
}


/* Sample the 3D posterior density at screen-cube coordinates (x, y, z):
 * rotate the point into celestial coordinates, look up the HEALPix pixel,
 * and evaluate prob * (conditional Gaussian in distance) for that pixel.
 * Returns 0 for pixels with no distance information (non-finite mu). */
static double bayestar_volume_render_inner(
    double x, double y, double z, int axis0, int axis1, int axis2,
    const double *R, long long nside, int nest, const double *prob, const
    double *mu, const double *sigma, const double *norm)
{
    double ret;
    /* Scatter the three screen coordinates into their world axes. */
    double xyz[3];
    xyz[axis0] = x;
    xyz[axis1] = y;
    xyz[axis2] = z;

   /* Transform from screen-aligned cube to celestial coordinates before
    * looking up pixel indices. */
    double vec[3];
    cblas_dgemv(
        CblasRowMajor, CblasNoTrans, 3, 3, 1, R, 3, xyz, 1, 0, vec, 1);
    int64_t ipix;
    if (nest)
        vec2pix_nest64(nside, vec, &ipix);
    else
        vec2pix_ring64(nside, vec, &ipix);
    /* Radial distance of the sample point from the origin. */
    double r = sqrt(gsl_pow_2(x) + gsl_pow_2(y) + gsl_pow_2(z));

    if (isfinite(mu[ipix]))
        ret = gsl_sf_exp_mult(
            -0.5 * gsl_pow_2((r - mu[ipix]) / sigma[ipix]),
            prob[ipix] * norm[ipix] / sigma[ipix]);
    else
        ret = 0;
    return ret;
}


/* Render one pixel of a volumetric visualization of the posterior.
 * Integrates the 3-D probability density along the line of sight through
 * screen position (x, y), where the screen plane is spanned by axes
 * axis0 and axis1, out to +/- max_distance along the remaining axis.
 * R is a 3x3 row-major rotation from the screen-aligned frame to
 * celestial coordinates; (prob, mu, sigma, norm) are the layers of a
 * HEALPix sky map with the given nside and ordering (nest != 0 for
 * NESTED). */
double bayestar_volume_render(
    double x, double y, double max_distance, int axis0, int axis1,
    const double *R, long long nside, int nest,
    const double *prob, const double *mu,
    const double *sigma, const double *norm)
{
    /* Determine which axis to integrate over
     * (the one that isn't in the args) */
    int axis2;
    int axes[] = {0, 0, 0};
    axes[axis0] = 1;
    axes[axis1] = 1;
    for (axis2 = 0; axes[axis2]; axis2++)
        ; /* loop body intentionally no-op */

    /* Construct grid in theta, the elevation angle from the
     * spatial origin to the plane of the screen. */

    /* Transverse distance from origin to point on screen */
    const double a = sqrt(gsl_pow_2(x) + gsl_pow_2(y));

    /* Maximum value of theta (at edge of screen-aligned cube) */
    const double theta_max = atan2(max_distance, a);
    /* Angular step comparable to the HEALPix pixel scale at this nside. */
    const double dtheta = 0.5 * M_PI / nside / 4;

    double ret = 0;

    /* Far from the center of the image, we integrate in theta so that we
     * step through HEALPix pixels at an approximately uniform rate.
     *
     * In the central 10% of the image, we integrate in z to avoid the
     * coordinate singularity in theta.
     */
    if (a >= 5e-2 * max_distance)
    {
        /* Construct regular grid from -theta_max to +theta_max */
        for (double theta = -theta_max; theta <= theta_max; theta += dtheta)
        {
            /* Differential z = a tan(theta),
             * dz = dz/dtheta dtheta
             *    = a tan'(theta) dtheta
             *    = a sec^2(theta) dtheta,
             * and dtheta = const */
            const double dz_dtheta = a / gsl_pow_2(cos(theta));
            const double z = a * tan(theta);
            ret += bayestar_volume_render_inner(x, y, z, axis0, axis1, axis2,
                R, nside, nest, prob, mu, sigma, norm) * dz_dtheta;
        }
        ret *= dtheta;
    } else {
        /* Uniform step in z chosen to match the theta grid resolution. */
        const double dz = max_distance * dtheta / theta_max;
        for (double z = -max_distance; z <= max_distance; z += dz)
        {
            ret += bayestar_volume_render_inner(x, y, z, axis0, axis1, axis2,
                R, nside, nest, prob, mu, sigma, norm);
        }
        ret *= dz;
    }
    /* npix / (4 pi) converts per-pixel probability to probability per
     * steradian; 1 / sqrt(2 pi) completes the Gaussian normalization of
     * the distance ansatz. */
    ret *= nside2npix64(nside) / (4 * M_PI * sqrt(2 * M_PI));
    return ret;
}


/* Marginal posterior distance PDF at distance r: the conditional PDF of
 * each pixel, weighted by that pixel's sky probability, summed over all
 * npix pixels (in parallel with OpenMP). */
double bayestar_distance_marginal_pdf(
    double r, long long npix,
    const double *prob, const double *mu,
    const double *sigma, const double *norm)
{
    double accum = 0;
    #pragma omp parallel for reduction(+:accum)
    for (long long ipix = 0; ipix < npix; ipix++)
    {
        accum += prob[ipix] * bayestar_distance_conditional_pdf(
            r, mu[ipix], sigma[ipix], norm[ipix]);
    }
    return accum;
}


/* Marginal posterior distance CDF at distance r: the conditional CDF of
 * each pixel, weighted by that pixel's sky probability, summed over all
 * npix pixels (in parallel with OpenMP). */
double bayestar_distance_marginal_cdf(
    double r, long long npix,
    const double *prob, const double *mu,
    const double *sigma, const double *norm)
{
    double accum = 0;
    #pragma omp parallel for reduction(+:accum)
    for (long long ipix = 0; ipix < npix; ipix++)
    {
        accum += prob[ipix] * bayestar_distance_conditional_cdf(
            r, mu[ipix], sigma[ipix], norm[ipix]);
    }
    return accum;
}


/* Parameter bundle handed to the GSL root solver when inverting the
 * marginal distance CDF (see bayestar_distance_marginal_ppf). */
typedef struct {
    double p;            /* target cumulative probability */
    long long npix;      /* number of HEALPix pixels */
    const double *prob;  /* per-pixel sky probabilities */
    const double *mu;    /* per-pixel distance location parameters */
    const double *sigma; /* per-pixel distance scale parameters */
    const double *norm;  /* per-pixel distance normalization constants */
} marginal_ppf_params;


/* Objective function and its derivative for the marginal PPF root
 * solve.  The root of f(r) is where the marginal CDF equals the target
 * probability.  The objective is formulated in log space, using the
 * survival function for p > 0.5 to preserve precision in the upper
 * tail. */
static void marginal_ppf_fdf(double r, void *params, double *f, double *df)
{
    const marginal_ppf_params *args = params;
    const double cdf = bayestar_distance_marginal_cdf(
        r, args->npix, args->prob, args->mu, args->sigma, args->norm);
    const double pdf = bayestar_distance_marginal_pdf(
        r, args->npix, args->prob, args->mu, args->sigma, args->norm);

    if (args->p > 0.5)
    {
        *f = log(1 - cdf) - log(1 - args->p);
        *df = -pdf / (1 - cdf);
    } else {
        *f = log(cdf) - log(args->p);
        *df = pdf / cdf;
    }
}


/* Value-only adapter for the GSL fdf solver interface. */
static double marginal_ppf_f(double r, void *params)
{
    double value, slope;
    marginal_ppf_fdf(r, params, &value, &slope);
    return value;
}


/* Derivative-only adapter for the GSL fdf solver interface. */
static double marginal_ppf_df(double r, void *params)
{
    double value, slope;
    marginal_ppf_fdf(r, params, &value, &slope);
    return slope;
}


/* Starting point for the marginal PPF root solve: the conditional PPF
 * of the highest-probability pixel that has finite distance
 * information, or a fallback of 100 (Mpc) when no pixel qualifies. */
static double marginal_ppf_initial_guess(
    double p, long long npix,
    const double *prob, const double *mu,
    const double *sigma, const double *norm)
{
    long long best = -1;
    double best_prob = -INFINITY;

    /* Scan for the most probable pixel with valid distance info. */
    for (long long ipix = 0; ipix < npix; ipix++)
    {
        if (!isfinite(mu[ipix]))
            continue;
        if (prob[ipix] > best_prob)
        {
            best_prob = prob[ipix];
            best = ipix;
        }
    }

    if (best < 0)
        return 100; /* no valid distance info anywhere: guess 100 Mpc */

    return bayestar_distance_conditional_ppf(
        p, mu[best], sigma[best], norm[best]);
}


/* Inverse of bayestar_distance_marginal_cdf: return the distance r at
 * which the marginal distance CDF equals p.  Returns 0 for p <= 0,
 * +infinity for p >= 1, and NaN for non-finite p. */
double bayestar_distance_marginal_ppf(
    double p, long long npix,
    const double *prob, const double *mu,
    const double *sigma, const double *norm)
{
    if (p <= 0)
        return 0;
    else if (p >= 1)
        return GSL_POSINF;
    else if (!isfinite(p))
        return GSL_NAN;

    /* Set up variables for tracking progress toward the solution. */
    static const int max_iter = 50;
    marginal_ppf_params params = {p, npix, prob, mu, sigma, norm};
    int iter = 0;
    double r = marginal_ppf_initial_guess(p, npix, prob, mu, sigma, norm);
    int status;

    /* Set up solver (on stack). */
    /* NOTE: this builds a gsl_root_fdfsolver by hand, with the state
     * buffer in a stack VLA, to avoid heap allocation; it depends on the
     * layout of GSL's solver struct. */
    const gsl_root_fdfsolver_type *algo = gsl_root_fdfsolver_steffenson;
    char state[algo->size];
    gsl_root_fdfsolver solver = {algo, NULL, 0, state};
    gsl_function_fdf fun = {
        marginal_ppf_f, marginal_ppf_df, marginal_ppf_fdf,
        &params};
    gsl_root_fdfsolver_set(&solver, &fun, r);

    /* Iterate until successive roots agree to a relative tolerance of
     * sqrt(DBL_EPSILON) or max_iter is reached. */
    do
    {
        const double rold = r;
        status = gsl_root_fdfsolver_iterate(&solver);
        r = gsl_root_fdfsolver_root(&solver);
        status = gsl_root_test_delta (r, rold, 0, GSL_SQRT_DBL_EPSILON);
        iter++;
    } while (status == GSL_CONTINUE && iter < max_iter);
    /* FIXME: do something with status? */

    return r;
}

src/bayestar_moc.c

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
/*
 * Copyright (C) 2017-2024  Leo Singer
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */


#include "bayestar_moc.h"
#include <math.h>
#include <gsl/gsl_errno.h>
#include <gsl/gsl_math.h>
#include <stdlib.h>
#include <string.h>
#include <chealpix.h>

#include "branch_prediction.h"


/* Convert a HEALPix (resolution order, NESTED index) pair to a NUNIQ
 * index: uniq = nest + 4^(order + 1).  Returns -1 when the nested index
 * is negative (invalid). */
int64_t nest2uniq64(uint8_t order, int64_t nest)
{
    if (nest < 0)
        return -1;
    const int64_t base = (int64_t) 1 << (2 * (order + 1));
    return base + nest;
}


/* Extract the HEALPix resolution order from a NUNIQ index:
 * order = floor(log2(uniq)) / 2 - 1.  Returns -1 when uniq is not a
 * valid NUNIQ index (valid indices are >= 4). */
int8_t uniq2order64(int64_t uniq)
{
    if (uniq < 4)
        return -1;

    int8_t order;
/* Fix: the old condition also selected 32-bit x86 (__i386 / _M_IX86),
 * where the 64-bit BSRQ instruction cannot be encoded, and MSVC
 * (_M_X64 / _M_IX86), which rejects GNU extended asm entirely.  Use the
 * asm fast path only on x86-64 with a GNU-compatible compiler. */
#if defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__))
    int64_t o;
    asm("bsrq %1, %0\n\t"
        : "=r" (o)
        : "rm" (uniq));
    order = o;
#else
    /* Portable fallback: position of the highest set bit. */
    order = 63 - __builtin_clzll(uniq);
#endif
    return (order >> 1) - 1;
}


/* Solid angle in steradians of the HEALPix pixel with NUNIQ index uniq,
 * or NaN when uniq is invalid.  Each of the 12 base pixels covers
 * pi / 3 sr, and every additional order divides the area by 4. */
double uniq2pixarea64(int64_t uniq)
{
    const int8_t order = uniq2order64(uniq);
    return order < 0 ? GSL_NAN : ldexp(M_PI / 3, -2 * order);
}


/* Convert a NUNIQ index back to a (order, NESTED index) pair.  The
 * nested index is stored through *nest (-1 when invalid) and the order
 * is returned (-1 when invalid). */
int8_t uniq2nest64(int64_t uniq, int64_t *nest)
{
    const int8_t order = uniq2order64(uniq);
    if (order < 0)
    {
        *nest = -1;
    } else {
        *nest = uniq - ((int64_t) 1 << 2 * (order + 1));
    }
    return order;
}


/* Spherical coordinates (theta, phi) of the center of the HEALPix pixel
 * with NUNIQ index uniq.  Both outputs are NaN when uniq is invalid. */
void uniq2ang64(int64_t uniq, double *theta, double *phi)
{
    int64_t nest;
    const int8_t order = uniq2nest64(uniq, &nest);

    if (order < 0)
    {
        *theta = GSL_NAN;
        *phi = GSL_NAN;
        return;
    }

    pix2ang_nest64((int64_t) 1 << order, nest, theta, phi);
}


/* Rasterize a multi-order coverage (MOC) data set to a flat HEALPix
 * image.
 *
 * pixels points to len records, each offset + itemsize bytes long; a
 * record begins with an int64 NUNIQ index and its payload of itemsize
 * bytes starts at byte offset.  If order >= 0, rasterize at that order;
 * if order < 0, rasterize at the highest order present in the data.
 * On success, returns a calloc'ed array of *npix payload items (the
 * caller frees it); on failure, returns NULL via GSL_ERROR_NULL. */
void *moc_rasterize64(
    const void *pixels, size_t offset, size_t itemsize, size_t len,
    size_t *npix, int8_t order)
{
    /* Calculate pixel size. */
    const size_t pixelsize = offset + itemsize;

    /* If the parameter order >= 0, then rasterize at that order.
     * Otherwise, find maximum order. Note: normally MOC datasets are stored in
     * order of ascending MOC index, so the last pixel should have the highest
     * order. However, our rasterization algorithm doesn't depend on this
     * sorting, so let's just do a linear search for the maximum order. */
    int8_t max_order;
    {
        int64_t max_uniq = 0;
        for (size_t i = 0; i < len; i ++)
        {
            const void *pixel = (const char *) pixels + i * pixelsize;
            const int64_t uniq = *(const int64_t *) pixel;
            if (uniq > max_uniq)
                max_uniq = uniq;
        }
        max_order = uniq2order64(max_uniq);
    }
    if (UNLIKELY(max_order < 0)) {
        GSL_ERROR_NULL("invalid UNIQ value", GSL_EINVAL);
    }

    /* Don't handle downsampling here, because we don't know how to do
     * reduction across pixels without more knowledge of the pixel datatype and
     * contents. */
    if (order >= max_order)
        max_order = order;
    else if (order >= 0)
        GSL_ERROR_NULL("downsampling not implemented", GSL_EUNIMPL);

    /* Allocate output. */
    /* npix = 12 * 4^max_order for a full HEALPix map at max_order. */
    *npix = 12 * ((size_t) 1 << 2 * max_order);
    void *ret = calloc(*npix, itemsize);
    if (!ret)
        GSL_ERROR_NULL("not enough memory to allocate image", GSL_ENOMEM);

    /* Paint pixels into output. */
    for (size_t i = 0; i < len; i ++)
    {
        const void *pixel = (const char *) pixels + i * pixelsize;
        int64_t nest;
        /* Note: the order parameter is reused here as the per-record order. */
        order = uniq2nest64(*(const int64_t *) pixel, &nest);
        if (UNLIKELY(order < 0)) {
            free(ret);
            GSL_ERROR_NULL("invalid UNIQ value", GSL_EINVAL);
        }
        /* Each MOC record covers 4^(max_order - order) output pixels. */
        const size_t reps = (size_t) 1 << 2 * (max_order - order);
        for (size_t j = 0; j < reps; j ++)
            memcpy((char *) ret + (nest * reps + j) * itemsize,
                (const char *) pixel + offset, itemsize);
    }

    return ret;
}

src/bayestar_sky_map.c

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
548  
549  
550  
551  
552  
553  
554  
555  
556  
557  
558  
559  
560  
561  
562  
563  
564  
565  
566  
567  
568  
569  
570  
571  
572  
573  
574  
575  
576  
577  
578  
579  
580  
581  
582  
583  
584  
585  
586  
587  
588  
589  
590  
591  
592  
593  
594  
595  
596  
597  
598  
599  
600  
601  
602  
603  
604  
605  
606  
607  
608  
609  
610  
611  
612  
613  
614  
615  
616  
617  
618  
619  
620  
621  
622  
623  
624  
625  
626  
627  
628  
629  
630  
631  
632  
633  
634  
635  
636  
637  
638  
639  
640  
641  
642  
643  
644  
645  
646  
647  
648  
649  
650  
651  
652  
653  
654  
655  
656  
657  
658  
659  
660  
661  
662  
663  
664  
665  
666  
667  
668  
669  
670  
671  
672  
673  
674  
675  
676  
677  
678  
679  
680  
681  
682  
683  
684  
685  
686  
687  
688  
689  
690  
691  
692  
693  
694  
695  
696  
697  
698  
699  
700  
701  
702  
703  
704  
705  
706  
707  
708  
709  
710  
711  
712  
713  
714  
715  
716  
717  
718  
719  
720  
721  
722  
723  
724  
725  
726  
727  
728  
729  
730  
731  
732  
733  
734  
735  
736  
737  
738  
739  
740  
741  
742  
743  
744  
745  
746  
747  
748  
749  
750  
751  
752  
753  
754  
755  
756  
757  
758  
759  
760  
761  
762  
763  
764  
765  
766  
767  
768  
769  
770  
771  
772  
773  
774  
775  
776  
777  
778  
779  
780  
781  
782  
783  
784  
785  
786  
787  
788  
789  
790  
791  
792  
793  
794  
795  
796  
797  
798  
799  
800  
801  
802  
803  
804  
805  
806  
807  
808  
809  
810  
811  
812  
813  
814  
815  
816  
817  
818  
819  
820  
821  
822  
823  
824  
825  
826  
827  
828  
829  
830  
831  
832  
833  
834  
835  
836  
837  
838  
839  
840  
841  
842  
843  
844  
845  
846  
847  
848  
849  
850  
851  
852  
853  
854  
855  
856  
857  
858  
859  
860  
861  
862  
863  
864  
865  
866  
867  
868  
869  
870  
871  
872  
873  
874  
875  
876  
877  
878  
879  
880  
881  
882  
883  
884  
885  
886  
887  
888  
889  
890  
891  
892  
893  
894  
895  
896  
897  
898  
899  
900  
901  
902  
903  
904  
905  
906  
907  
908  
909  
910  
911  
912  
913  
914  
915  
916  
917  
918  
919  
920  
921  
922  
923  
924  
925  
926  
927  
928  
929  
930  
931  
932  
933  
934  
935  
936  
937  
938  
939  
940  
941  
942  
943  
944  
945  
946  
947  
948  
949  
950  
951  
952  
953  
954  
955  
956  
957  
958  
959  
960  
961  
962  
963  
964  
965  
966  
967  
968  
969  
970  
971  
972  
973  
974  
975  
976  
977  
978  
979  
980  
981  
982  
983  
984  
985  
986  
987  
988  
989  
990  
991  
992  
993  
994  
995  
996  
997  
998  
999  
1000  
1001  
1002  
1003  
1004  
1005  
1006  
1007  
1008  
1009  
1010  
1011  
1012  
1013  
1014  
1015  
1016  
1017  
1018  
1019  
1020  
1021  
1022  
1023  
1024  
1025  
1026  
1027  
1028  
1029  
1030  
1031  
1032  
1033  
1034  
1035  
1036  
1037  
1038  
1039  
1040  
1041  
1042  
1043  
1044  
1045  
1046  
1047  
1048  
1049  
1050  
1051  
1052  
1053  
1054  
1055  
1056  
1057  
1058  
1059  
1060  
1061  
1062  
1063  
1064  
1065  
1066  
1067  
1068  
1069  
1070  
1071  
1072  
1073  
1074  
1075  
1076  
1077  
1078  
1079  
1080  
1081  
1082  
1083  
1084  
1085  
1086  
1087  
1088  
1089  
1090  
1091  
1092  
1093  
1094  
1095  
1096  
1097  
1098  
1099  
1100  
1101  
1102  
1103  
1104  
1105  
1106  
1107  
1108  
1109  
1110  
1111  
1112  
1113  
1114  
1115  
1116  
1117  
1118  
1119  
1120  
1121  
1122  
1123  
1124  
1125  
1126  
1127  
1128  
1129  
1130  
1131  
1132  
1133  
1134  
1135  
1136  
1137  
1138  
1139  
1140  
1141  
1142  
1143  
1144  
1145  
1146  
1147  
1148  
1149  
1150  
1151  
1152  
1153  
1154  
1155  
1156  
1157  
1158  
1159  
1160  
1161  
1162  
1163  
1164  
1165  
1166  
1167  
1168  
1169  
1170  
1171  
1172  
1173  
1174  
1175  
1176  
1177  
1178  
1179  
1180  
1181  
1182  
1183  
1184  
1185  
1186  
1187  
1188  
1189  
1190  
1191  
1192  
1193  
1194  
1195  
1196  
1197  
1198  
1199  
1200  
1201  
1202  
1203  
1204  
1205  
1206  
1207  
1208  
1209  
1210  
1211  
1212  
1213  
1214  
1215  
1216  
1217  
1218  
1219  
1220  
1221  
1222  
1223  
1224  
1225  
1226  
1227  
1228  
1229  
1230  
1231  
1232  
1233  
1234  
1235  
1236  
1237  
1238  
1239  
1240  
1241  
1242  
1243  
1244  
1245  
1246  
1247  
1248  
1249  
1250  
1251  
1252  
1253  
1254  
1255  
1256  
1257  
1258  
1259  
1260  
1261  
1262  
1263  
1264  
1265  
1266  
1267  
1268  
1269  
1270  
1271  
1272  
1273  
1274  
1275  
1276  
1277  
1278  
1279  
1280  
1281  
1282  
1283  
1284  
1285  
1286  
1287  
1288  
1289  
1290  
1291  
1292  
1293  
1294  
1295  
1296  
1297  
1298  
1299  
1300  
1301  
1302  
1303  
1304  
1305  
1306  
1307  
1308  
1309  
1310  
1311  
1312  
1313  
1314  
1315  
1316  
1317  
1318  
1319  
1320  
1321  
1322  
1323  
1324  
1325  
1326  
1327  
1328  
1329  
1330  
1331  
1332  
1333  
1334  
1335  
1336  
1337  
1338  
1339  
1340  
1341  
1342  
1343  
1344  
1345  
1346  
1347  
1348  
1349  
1350  
1351  
1352  
1353  
1354  
1355  
1356  
1357  
1358  
1359  
1360  
1361  
1362  
1363  
1364  
1365  
1366  
1367  
1368  
1369  
1370  
1371  
1372  
1373  
1374  
1375  
1376  
1377  
1378  
1379  
1380  
1381  
1382  
1383  
1384  
1385  
1386  
1387  
1388  
1389  
1390  
1391  
1392  
1393  
1394  
1395  
1396  
1397  
1398  
1399  
1400  
1401  
1402  
1403  
1404  
1405  
1406  
1407  
1408  
1409  
1410  
1411  
1412  
1413  
1414  
1415  
1416  
1417  
1418  
1419  
1420  
1421  
1422  
1423  
1424  
1425  
1426  
1427  
1428  
1429  
1430  
1431  
1432  
1433  
1434  
1435  
1436  
1437  
1438  
1439  
1440  
1441  
1442  
1443  
1444  
1445  
1446  
1447  
1448  
/*                                           >y#
                                            ~'#o+
                                           '~~~md~
                '|+>#!~'':::::....        .~'~'cY#
            .+oy+>|##!~~~''':::......     ~:'':md! .
          #rcmory+>|#~''':::'::...::.::. :..'''Yr:...
        'coRRaamuyb>|!~'''::::':...........  .+n|.::..
       !maMMNMRYmuybb|!~'''':.........::::::: ro'..::..
      .cODDMYouuurub!':::...........:::~'.. |o>::...:..
      >BDNCYYmroyb>|#~:::::::::::::~':.:: :ob::::::::..
      uOCCNAa#'''''||':::.                :oy':::::::::.
    :rRDn!  :~::'y+::':  ... ...:::.     :ob':::::::::::.
   yMYy:   :>yooCY'.':.   .:'':......    ~u+~::::::::::::.
  >>:'. .~>yBDMo!.'': . .:'':.   .      >u|!:::::::::::::.
    ':'~|mYu#:'~'''. :.~:':...         yy>|~:::::::::::::..
    :!ydu>|!rDu::'. +'#~::!#'.~:     |r++>#':::::::::::::..
    mn>>>>>YNo:'': !# >'::::...  ..:cyb++>!:::::::::..:::...
    :ouooyodu:'': .!:.!:::.       yobbbb+>~::::::::....:....
     'cacumo~''' .'~ :~'.::.    :aybbbbbb>':::'~''::::....
      .mamd>'''. :~' :':'.:.   om>bbbyyyb>'.#b>|#~~~'':..
      .yYYo''': .:~' .'::'   .ny>+++byyoao!b+|||#!~~~''''''::.
      .#RUb:''. .:'' .:':   |a#|>>>>yBMdb #yb++b|':::::''':'::::::.
      .'CO!'''  .:'' .'    uu~##|+mMYy>+:|yyo+:::'::.         .::::::
      .:RB~''' ..::'.':   o>~!#uOOu>bby'|yB>.'::  '~!!!!!~':. ..  .::::
       :Rm''': ..:~:!:  'c~~+YNnbyyybb~'mr.':  !+yoy+>||!~'::.       :::.
      ..Oo''': .'' ~:  !+|BDCryuuuuub|#B!::  !rnYaocob|#!~'':.  ..    .::.
      . nB''': :  .'  |dNNduroomnddnuun::.  ydNAMMOary+>#~:.:::...      .:
       .uC~'''    :. yNRmmmadYUROMMBmm.:   bnNDDDMRBoy>|#~':....:.      .:
                 :' ymrmnYUROMAAAAMYn::. .!oYNDDMYmub|!~'::....:..     :
                 !'#booBRMMANDDDNNMO!:. !~#ooRNNAMMOOmuy+#!':::.......    :.
                .!'!#>ynCMNDDDDDNMRu.. '|:!raRMNAMOOdooy+|!~:::........   .:
                 : .'rdbcRMNNNNAMRB!:  |!:~bycmdYYBaoryy+|!~':::.::::.  ..
                 ..~|RMADnnONAMMRdy:. .>#::yyoroccruuybb>#!~'::':...::.
                  :'oMOMOYNMnybyuo!.  :>#::b+youuoyyy+>>|!~':.    :::::
                  ''YMCOYYNMOCCCRdoy##~~~: !b>bb+>>>||#~:..:::     ::::.
                  .:OMRCoRNAMOCROYYUdoy|>~:.~!!~!~~':...:'::::.   :::::.
                  ''oNOYyMNAMMMRYnory+|!!!:.....     ::.  :'::::::::::::
                 .:..uNabOAMMCOdcyb+|!~':::.          !!'.. :~:::::'''':.
                  .   +Y>nOORYauyy>!!'':....           !#~..  .~:''''''':.

****************  ____  _____  ______________________    ____     **************
***************  / __ )/   \ \/ / ____/ ___/_  __/   |  / __ \   ***************
**************  / __  / /| |\  / __/  \__ \ / / / /| | / /_/ /  ****************
*************  / /_/ / ___ |/ / /___ ___/ // / / ___ |/ _, _/  *****************
************  /_____/_/  |_/_/_____//____//_/ /_/  |_/_/ |_|  ******************
*/


/*
 * Copyright (C) 2013-2024  Leo Singer
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "bayestar_cosmology.h"
#include "bayestar_sky_map.h"
#include "bayestar_distance.h"
#include "bayestar_moc.h"
#include "omp_interruptible.h"

#include <assert.h>
#include <float.h>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <pthread.h>

#include "cubic_interp.h"

#include <chealpix.h>

#include <gsl/gsl_cdf.h>
#include <gsl/gsl_errno.h>
#include <gsl/gsl_integration.h>
#include <gsl/gsl_interp.h>
#include <gsl/gsl_math.h>
#include <gsl/gsl_sf_bessel.h>
#include <gsl/gsl_sf_exp.h>
#include <gsl/gsl_sf_expint.h>
#include <gsl/gsl_sf_gamma.h>
#include <gsl/gsl_spline.h>
#include <gsl/gsl_test.h>

#include "branch_prediction.h"

#ifdef WITH_ITTNOTIFY
#include <ittnotify.h>
static __itt_domain *itt_domain;
static __itt_string_handle
    *itt_task_lookup_table,
    *itt_task_initial_step,
    *itt_task_refinement_step,
    *itt_task_final_step,
    *itt_task_log_posterior;

#define ITT_TASK_BEGIN(domain, task) __itt_task_begin((domain), __itt_null, __itt_null, (task))
#define ITT_TASK_END(domain) __itt_task_end((domain))
#else
#define ITT_TASK_BEGIN(domain, task)
#define ITT_TASK_END(domain)
#endif

/* Loop count hints */
#if defined(__INTEL_COMPILER) || defined(__ICL) || defined(__ICC)
#define PRAGMA_LOOP_COUNT_NINT _Pragma("loop count min(1), max(2), avg(2)")
#define PRAGMA_LOOP_COUNT_NIFOS _Pragma("loop count min(1), max(5), avg(2)")
#define PRAGMA_LOOP_COUNT_NSAMPLES _Pragma("loop count min(1), max(128), avg(16)")
#elif defined(__clang__) || defined(__llvm__)
#define PRAGMA_LOOP_COUNT_NINT _Pragma("unroll 2")
#define PRAGMA_LOOP_COUNT_NIFOS _Pragma("unroll 2")
#define PRAGMA_LOOP_COUNT_NSAMPLES _Pragma("unroll 16")
#else /* assume GCC */
#define PRAGMA_LOOP_COUNT_NINT _Pragma("GCC unroll 2")
#define PRAGMA_LOOP_COUNT_NIFOS _Pragma("GCC unroll 2")
#define PRAGMA_LOOP_COUNT_NSAMPLES _Pragma("GCC unroll 16")
#endif


/* Compute |z|^2. Hopefully a little faster than gsl_pow_2(cabs(z)), because no
 * square roots are necessary. */
/* Squared magnitude |z|^2, computed without the square root that
 * cabsf would take. */
static float cabs2(float complex z) {
    const float re = crealf(z);
    const float im = cimagf(z);
    return re * re + im * im;
}


static float complex exp_i(float phi) {
    return cosf(phi) + I * sinf(phi);
}


/*
 * Catmull-Rom cubic spline interpolant of x(t) for regularly gridded
 * samples x_i(t_i), assuming:
 *
 *     t_0 = -1, x_0 = x[0],
 *     t_1 = 0,  x_1 = x[1],
 *     t_2 = 1,  x_2 = x[2],
 *     t_3 = 2,  x_3 = x[3].
 */
/*
 * Catmull-Rom cubic spline interpolant of x(t) for regularly gridded
 * samples x_i(t_i), assuming:
 *
 *     t_0 = -1, x_0 = x[0],
 *     t_1 = 0,  x_1 = x[1],
 *     t_2 = 1,  x_2 = x[2],
 *     t_3 = 2,  x_3 = x[3].
 */
static float catrom(float x0, float x1, float x2, float x3, float t) {
    /* Polynomial coefficients of the standard Catmull-Rom form,
     * evaluated by Horner's rule. */
    const float c1 = 0.5f * (x2 - x0);
    const float c2 = x0 - 2.5f * x1 + 2.0f * x2 - 0.5f * x3;
    const float c3 = 0.5f * (x3 - x0) + 1.5f * (x1 - x2);
    return x1 + t * (c1 + t * (c2 + t * c3));
}


/* Evaluate a complex time series using cubic spline interpolation, assuming
 * that the vector x gives the samples of the time series at times
 * 0, 1, ..., nsamples-1. */
/* Evaluate a complex time series at fractional sample index t by cubic
 * spline interpolation.  x[i][0] and x[i][1] are interpolated
 * separately and combined as x[i][0] * exp(i * x[i][1]); samples are
 * assumed to lie at times 0, 1, ..., nsamples - 1.  Returns 0 where the
 * four-point stencil would fall outside the series. */
static float complex eval_snr(const float (*x)[2], size_t nsamples, float t) {
    float int_part;

    /* Split t into its integer and fractional parts. */
    const float frac = modff(t, &int_part);
    const ssize_t i = int_part;

    /* The cubic stencil needs one sample of padding on each side. */
    if (i < 1 || i >= (ssize_t)nsamples - 2)
        return 0;

    const float amp = catrom(x[i-1][0], x[i][0], x[i+1][0], x[i+2][0], frac);
    const float arg = catrom(x[i-1][1], x[i][1], x[i+1][1], x[i+2][1], frac);
    return amp * exp_i(arg);
}


/* Precomputed interpolants for fast evaluation of the log radial
 * integral.  region0 is a 2-D (bicubic) lookup table; region1 and
 * region2 are 1-D (cubic) tables used outside region0's domain (see
 * log_radial_integrator_init for how each region is built).  ymax and
 * vmax bound the table domain; p0_limit is the analytic value of the
 * log integral of the r^k prior alone, used in the p -> 0 limit. */
typedef struct {
    bicubic_interp *region0;
    cubic_interp *region1;
    cubic_interp *region2;
    double ymax, vmax, p0_limit;
} log_radial_integrator;


/* Parameters shared by radial_integrand and log_radial_integrand. */
typedef struct {
    double scale;       /* additive log-space offset; set so the rescaled integrand peaks near 1 */
    double p;           /* amplitude parameter in the likelihood term -(p/r - b/(2p))^2 */
    double b;           /* argument scale of the Bessel factor I0_scaled(b / r) */
    int k, cosmology;   /* k: exponent of the r^k distance prior; cosmology: nonzero adds log_dVC_dVL(r) */
} radial_integrand_params;


/* Uniform-in-comoving volume prior for the Planck15 cosmology.
 * This is implemented as a cubic spline interpolant.
 *
 * The following static variables are defined in bayestar_cosmology.h, which
 * is automatically generated by bayestar_cosmology.py:
 *     - dVC_dVL_data
 *     - dVC_dVL_tmin
 *     - dVC_dVL_dt
 *     - dVC_dVL_high_z_slope
 *     - dVC_dVL_high_z_intercept
 */
/* Cubic-spline interpolant of the tabulated comoving-volume correction;
 * created once by dVC_dVL_init and read by log_dVC_dVL. */
static gsl_spline *dVC_dVL_interp = NULL;
/* Build the spline over dVC_dVL_data on the regular log-distance grid
 * x[i] = dVC_dVL_tmin + i * dVC_dVL_dt. */
static void dVC_dVL_init(void)
{
    const size_t len = sizeof(dVC_dVL_data) / sizeof(*dVC_dVL_data);
    dVC_dVL_interp = gsl_spline_alloc(gsl_interp_cspline, len);
    assert(dVC_dVL_interp);
    double x[len];
    for (size_t i = 0; i < len; i ++)
        x[i] = dVC_dVL_tmin + i * dVC_dVL_dt;
    int ret = gsl_spline_init(dVC_dVL_interp, x, dVC_dVL_data, len);
    assert(ret == GSL_SUCCESS);
    (void)ret; /* Silence unused variable warning */
}


/* Log of the comoving-volume correction as a function of luminosity
 * distance DL, evaluated in log-distance space.  Below the tabulated
 * range the correction is zero; above it, a linear extrapolation is
 * used; in between, the spline prepared by dVC_dVL_init is evaluated. */
static double log_dVC_dVL(double DL)
{
    const double t = log(DL);

    if (t <= dVC_dVL_tmin)
        return 0.0;

    if (t >= dVC_dVL_tmax)
        return dVC_dVL_high_z_slope * t + dVC_dVL_high_z_intercept;

    return gsl_spline_eval(dVC_dVL_interp, t, NULL);
}


/* Radial integrand in linear space, rescaled by exp(scale) to keep
 * values representable:
 *     exp(scale - (p/r - b/(2p))^2) * I0_scaled(b/r) * r^k,
 * optionally multiplied by the comoving-volume prior. */
static double radial_integrand(double r, void *params)
{
    const radial_integrand_params *args = params;

    double log_factor =
        args->scale - gsl_pow_2(args->p / r - 0.5 * args->b / args->p);
    if (args->cosmology)
        log_factor += log_dVC_dVL(r);

    const double prefactor =
        gsl_sf_bessel_I0_scaled(args->b / r) * gsl_pow_int(r, args->k);
    return gsl_sf_exp_mult(log_factor, prefactor);
}


/* Natural logarithm of the radial integrand (the same quantity that
 * radial_integrand returns, evaluated directly in log space). */
static double log_radial_integrand(double r, void *params)
{
    const radial_integrand_params *args = params;

    double log_value =
        log(gsl_sf_bessel_I0_scaled(args->b / r) * gsl_pow_int(r, args->k))
        + args->scale - gsl_pow_2(args->p / r - 0.5 * args->b / args->p);
    if (args->cosmology)
        log_value += log_dVC_dVL(r);
    return log_value;
}

/* Compute the log of the radial likelihood integral from r1 to r2 of
 *     exp(-(p/r - b/(2p))^2) * I0_scaled(b/r) * r^k dr
 * (optionally weighted by the comoving-volume prior when cosmology is
 * nonzero), by adaptive quadrature with breakpoints near the peak of
 * the integrand, working on a rescaled integrand to avoid overflow. */
static double log_radial_integral(double r1, double r2, double p, double b, int k, int cosmology)
{
    radial_integrand_params params = {0, p, b, k, cosmology};
    double breakpoints[5];
    unsigned char nbreakpoints = 0;
    double result = 0, abserr, log_offset = -INFINITY;
    int ret;

    if (LIKELY(b != 0)) {
        /* Calculate the approximate distance at which the integrand attains a
         * maximum (middle) and a fraction eta of the maximum (left and right).
         * This neglects the scaled Bessel function factors and the power-law
         * distance prior. It assumes that the likelihood is approximately of
         * the form
         *
         *    -p^2/r^2 + B/r.
         *
         * Then the middle breakpoint occurs at 1/r = -B/2A, and the left and
         * right breakpoints occur when
         *
         *   A/r^2 + B/r = log(eta) - B^2/4A.
         */

        static const double eta = 0.01;
        const double middle = 2 * gsl_pow_2(p) / b;
        const double left = 1 / (1 / middle + sqrt(-log(eta)) / p);
        const double right = 1 / (1 / middle - sqrt(-log(eta)) / p);

        /* Use whichever of the middle, left, and right points lie within the
         * integration limits as initial subdivisions for the adaptive
         * integrator. */

        breakpoints[nbreakpoints++] = r1;
        if(left > breakpoints[nbreakpoints-1] && left < r2)
            breakpoints[nbreakpoints++] = left;
        if(middle > breakpoints[nbreakpoints-1] && middle < r2)
            breakpoints[nbreakpoints++] = middle;
        if(right > breakpoints[nbreakpoints-1] && right < r2)
            breakpoints[nbreakpoints++] = right;
        breakpoints[nbreakpoints++] = r2;
    } else {
        /* Inner breakpoints are undefined because b = 0. */
        breakpoints[nbreakpoints++] = r1;
        breakpoints[nbreakpoints++] = r2;
    }

    /* Re-scale the integrand so that the maximum value at any of the
     * breakpoints is 1. Note that the initial value of the constant term
     * is overwritten. */

    for (unsigned char i = 0; i < nbreakpoints; i++)
    {
        double new_log_offset = log_radial_integrand(breakpoints[i], &params);
        if (new_log_offset > log_offset)
            log_offset = new_log_offset;
    }

    /* If the largest value of the log integrand was -INFINITY, then the
     * integrand is 0 everywhere. Set log_offset to 0, because subtracting
     * -INFINITY would make the integrand infinite. */
    if (log_offset == -INFINITY)
        log_offset = 0;

    params.scale = -log_offset;

    {
        /* Maximum number of subdivisions for adaptive integration. */
        static const size_t n = 64;

        /* Allocate workspace on stack. Hopefully, a little bit faster than
         * using the heap in multi-threaded code. */
        /* NOTE: this fills in GSL's gsl_integration_workspace struct by
         * hand, so it depends on that struct's public layout. */

        double alist[n];
        double blist[n];
        double rlist[n];
        double elist[n];
        size_t order[n];
        size_t level[n];
        gsl_integration_workspace workspace = {
            .alist = alist,
            .blist = blist,
            .rlist = rlist,
            .elist = elist,
            .order = order,
            .level = level,
            .limit = n
        };

        /* Set up integrand data structure. */
        const gsl_function func = {radial_integrand, &params};

        /* Perform adaptive Gaussian quadrature. */
        ret = gsl_integration_qagp(&func, breakpoints, nbreakpoints,
            DBL_MIN, 1e-8, n, &workspace, &result, &abserr);

        /* FIXME: do we care to keep the error estimate around? */
    }

    /* FIXME: do something with ret */
    (void)ret;

    /* Undo the rescaling: log(integral) = log_offset + log(result). */
    return log_offset + log(result);
}


/* Default number of grid points per dimension for the lookup tables built
 * by log_radial_integrator_init. */
static const size_t default_log_radial_integrator_size = 400;


/* Build a log_radial_integrator: tabulate the marginalized distance
 * integral log_radial_integral(r1, r2, p, b, k, cosmology) on a regular
 * grid in (x, y) = (log p, log r0), where r0 = 2 p^2 / b, and fit
 * interpolants so that later evaluations reduce to table lookups.
 *
 * Three interpolation regions are constructed:
 *   region0: bicubic interpolant over the interior of the (x, y) grid;
 *   region1: 1-D cubic interpolant along the top edge y = ymax;
 *   region2: 1-D cubic interpolant along the grid's anti-diagonal in the
 *            rotated coordinate u = (x - y) / 2, used where v <= vmax.
 *
 * k is the power of r in the distance prior, pmax the largest amplitude
 * that will be queried, and size the number of grid points per dimension.
 * Returns NULL (with a GSL out-of-memory error) if allocation fails or
 * the OpenMP loop is interrupted. */
static log_radial_integrator *log_radial_integrator_init(double r1, double r2, int k, int cosmology, double pmax, size_t size)
{
    log_radial_integrator *integrator;
    bicubic_interp *region0 = NULL;
    cubic_interp *region1 = NULL, *region2 = NULL;
    const double alpha = 4;
    const double p0 = 0.5 * (k >= 0 ? r2 : r1);
    const double xmax = log(pmax);
    const double x0 = GSL_MIN_DBL(log(p0), xmax);
    const double xmin = x0 - (1 + M_SQRT2) * alpha;
    const double ymax = x0 + alpha;
    const double ymin = 2 * x0 - M_SQRT2 * alpha - xmax;
    const double d = (xmax - xmin) / (size - 1); /* dx = dy = du */
    const double umin = - (1 + M_SQRT1_2) * alpha;
    const double vmax = x0 - M_SQRT1_2 * alpha;
    double z0[size][size], z1[size], z2[size];
    double p0_limit;

    /* Value returned for p == 0 (and hence b == 0): the log of the bare
     * prior integral of r^k over [r1, r2]. */
    if (UNLIKELY(k == -1))
    {
        /* k == -1 is special because the antiderivative of 1/r is log. */
        p0_limit = log(log(r2 / r1));
    } else {
        int k1 = k + 1;
        p0_limit = log((gsl_pow_int(r2, k1) - gsl_pow_int(r1, k1)) / k1);
    }

    /* const double umax = xmax - vmax; */ /* unused */

    int interrupted;
    OMP_BEGIN_INTERRUPTIBLE
    integrator = malloc(sizeof(*integrator));

    /* Fill the 2-D table of log radial integrals. */
    #pragma omp taskloop collapse(2) shared(z0)
    for (size_t ix = 0; ix < size; ix ++)
    {
        for (size_t iy = 0; iy < size; iy ++)
        {
            if (OMP_WAS_INTERRUPTED)
                OMP_EXIT_LOOP_EARLY;

            const double x = xmin + ix * d;
            const double y = ymin + iy * d;
            const double p = exp(x);
            const double r0 = exp(y);
            const double b = 2 * gsl_pow_2(p) / r0;
            /* Note: using this where p > r0; could reduce evaluations by half */
            z0[ix][iy] = log_radial_integral(r1, r2, p, b, k, cosmology);
        }
    }

    if (OMP_WAS_INTERRUPTED)
        goto done;

    region0 = bicubic_interp_init(*z0, size, size, xmin, ymin, d, d);

    /* Top edge of the grid (largest y) feeds the 1-D region1 interpolant. */
    for (size_t i = 0; i < size; i ++)
        z1[i] = z0[i][size - 1];
    region1 = cubic_interp_init(z1, size, xmin, d);

    /* The anti-diagonal of the grid feeds the 1-D region2 interpolant. */
    for (size_t i = 0; i < size; i ++)
        z2[i] = z0[i][size - 1 - i];
    region2 = cubic_interp_init(z2, size, umin, d);

done:
    interrupted = OMP_WAS_INTERRUPTED;
    OMP_END_INTERRUPTIBLE

    if (UNLIKELY(interrupted || !(integrator && region0 && region1 && region2)))
    {
        /* NOTE(review): the interp objects appear to be single allocations,
         * so plain free() here matches the *_interp_free behavior — confirm
         * if their layout ever changes. */
        free(integrator);
        free(region0);
        free(region1);
        free(region2);
        GSL_ERROR_NULL("not enough memory to allocate integrator", GSL_ENOMEM);
    }

    integrator->region0 = region0;
    integrator->region1 = region1;
    integrator->region2 = region2;
    integrator->ymax = ymax;
    integrator->vmax = vmax;
    integrator->p0_limit = p0_limit;
    return integrator;
}


/* Release an integrator created by log_radial_integrator_init.
 * Safe to call with a NULL pointer. */
static void log_radial_integrator_free(log_radial_integrator *integrator)
{
    if (!integrator)
        return;

    bicubic_interp_free(integrator->region0);
    cubic_interp_free(integrator->region1);
    cubic_interp_free(integrator->region2);
    integrator->region0 = NULL;
    integrator->region1 = NULL;
    integrator->region2 = NULL;
    free(integrator);
}


/* Evaluate the tabulated log radial integral at amplitude p and cross
 * term b.  log_p and log_b must equal log(p) and log(b); they are passed
 * in so that callers can hoist the logarithms out of their loops.  For
 * p == 0 (which implies b == 0) the precomputed limiting value of the
 * integral is returned. */
static double log_radial_integrator_eval(const log_radial_integrator *integrator, double p, double b, double log_p, double log_b)
{
    assert(p >= 0);

    if (p <= 0)
    {
        /* note: p2 == 0 implies b == 0 */
        assert(b < GSL_DBL_EPSILON);
        return integrator->p0_limit;
    }

    /* Table coordinates: x = log p, y = log(2 p^2 / b). */
    const double x = log_p;
    const double y = M_LN2 + 2 * log_p - log_b;
    double result = gsl_pow_2(0.5 * b / p);

    if (y >= integrator->ymax)
        result += cubic_interp_eval(integrator->region1, x);
    else if (0.5 * (x + y) <= integrator->vmax)
        result += cubic_interp_eval(integrator->region2, 0.5 * (x - y));
    else
        result += bicubic_interp_eval(integrator->region0, x, y);

    return result;
}


/* Find error in time of arrival: for each detector, add to its recorded
 * arrival time the geometric delay for a source at (theta, phi). */
static void toa_errors(
    double *dt,
    double theta,
    double phi,
    double gmst,
    int nifos,
    const double **locs, /* Input: detector position. */
    const double *toas /* Input: time of arrival. */
) {
    /* Unit vector toward the source, rotated into Earth-fixed frame. */
    double n[3];
    ang2vec(theta, phi - gmst, n);

    PRAGMA_LOOP_COUNT_NIFOS
    for (int i = 0; i < nifos; i ++)
    {
        /* Project the detector position onto the line of sight. */
        const double *loc = locs[i];
        double proj = 0;
        for (int j = 0; j < 3; j ++)
            proj += loc[j] * n[j];
        dt[i] = toas[i] + proj;
    }
}


/* Compute antenna factors from the detector response tensor and source
 * sky location, and return as a complex number F_plus + i F_cross. */
float complex antenna_factor(
    const float D[3][3],
    float ra,
    float dec,
    float gmst
) {
    /* Adapted from LAL's XLALComputeDetAMResponse with the following changes:
     * - All operations are single-precision rather than double-precision.
     * - psi is assumed to be 0.
     * - fplus and fcross are packed into a complex number.
     */
    const float gha = gmst - ra;
    const float cosgha = cosf(gha);
    const float singha = sinf(gha);
    const float cosdec = cosf(dec);
    const float sindec = sinf(dec);
    const float X[] = {-singha, -cosgha, 0};
    const float Y[] = {-cosgha * sindec, singha * sindec, cosdec};
    float complex F = 0;
    for(int i = 0; i < 3; i++) {
        const float DX = D[i][0] * X[0] + D[i][1] * X[1] + D[i][2] * X[2];
        const float DY = D[i][0] * Y[0] + D[i][1] * Y[1] + D[i][2] * Y[2];
        F += (X[i] * DX - Y[i] * DY) + (X[i] * DY + Y[i] * DX) * I;
    }
    return F;
}


/* Expression for the complex amplitude on arrival (without the 1/distance
 * factor).  The plus polarization picks up a factor (1 + u^2)/2 and the
 * cross polarization a factor u, after rotating the antenna pattern by
 * twice the polarization angle. */
float complex bayestar_signal_amplitude_model(
    float complex F,               /* Complex antenna factor */
    float complex exp_i_twopsi,    /* e^(i*2*psi), for polarization angle psi */
    float u,                       /* cos(inclination) */
    float u2                       /* cos^2(inclination) */
) {
    /* Rotate the antenna pattern by twice the polarization angle. */
    const float complex rotated = F * conjf(exp_i_twopsi);
    const float plus = 0.5f * (1 + u2) * crealf(rotated);
    const float cross = u * cimagf(rotated);
    return plus - cross * I;
}


/* Number of quadrature nodes over u = cos(inclination).  A macro because
 * it sizes fixed-length arrays below. */
#define nu 10
/* Number of evenly spaced samples of twice the polarization angle. */
static const unsigned int ntwopsi = 10;
/* Gauss-Legendre rule over [-1, 1]: column 0 is the abscissa, column 1
 * the log of the weight.  Filled once by u_points_weights_init(). */
static float u_points_weights[nu][2];


/* Populate u_points_weights with the order-nu Gauss-Legendre quadrature
 * nodes over [-1, 1] and the logarithms of their weights. */
static void u_points_weights_init(void)
{
    /* Look up Gauss-Legendre quadrature rule for integral over cos(i). */
    gsl_integration_glfixed_table *gltable
        = gsl_integration_glfixed_table_alloc(nu);

    /* GSL ships static, precomputed tables for this order, so the
     * allocation cannot fail and the table must not be freed.  See:
     *
     * http://git.savannah.gnu.org/cgit/gsl.git/tree/integration/glfixed.c
     */
    assert(gltable);
    assert(gltable->precomputed); /* We don't have to free it. */

    for (unsigned int iu = 0; iu < nu; iu++)
    {
        double point, weight;

        /* Look up Gauss-Legendre abscissa and weight. */
        const int ret = gsl_integration_glfixed_point(
            -1, 1, iu, &point, &weight, gltable);

        /* The only possible failure is index bounds checking, so the
         * return value is only inspected in debug builds. */
        assert(ret == GSL_SUCCESS);
        (void)ret; /* Silence unused variable warning */

        u_points_weights[iu][0] = point;
        u_points_weights[iu][1] = log(weight);
    }
}


/* qsort comparator ordering pixels by ascending contained probability
 * (log posterior density plus log pixel area, compared in log space;
 * pixel area shrinks by a factor of 4 — i.e. 2 log 2 in the log — per
 * HEALPix order). */
static int bayestar_pixel_compare_prob(const void *a, const void *b)
{
    const bayestar_pixel *lhs = a;
    const bayestar_pixel *rhs = b;

    const double diff = (lhs->value[0] - rhs->value[0])
        - 2 * M_LN2 * (uniq2order64(lhs->uniq) - uniq2order64(rhs->uniq));

    return (diff > 0) - (diff < 0);
}


/* Sort pixels in place by ascending contained probability. */
static void bayestar_pixels_sort_prob(bayestar_pixel *pixels, size_t len)
{
    qsort(pixels, len, sizeof(*pixels), bayestar_pixel_compare_prob);
}


/* qsort comparator ordering pixels by ascending NUNIQ index. */
static int bayestar_pixel_compare_uniq(const void *a, const void *b)
{
    const unsigned long long lhs = ((const bayestar_pixel *) a)->uniq;
    const unsigned long long rhs = ((const bayestar_pixel *) b)->uniq;

    return (lhs > rhs) - (lhs < rhs);
}


/* Sort pixels in place by ascending NUNIQ index. */
static void bayestar_pixels_sort_uniq(bayestar_pixel *pixels, size_t len)
{
    qsort(pixels, len, sizeof(*pixels), bayestar_pixel_compare_uniq);
}


/* Resize an allocation like realloc, but on failure release the original
 * buffer and raise a GSL out-of-memory error instead of leaking it. */
static void *realloc_or_free(void *ptr, size_t size)
{
    void *grown = realloc(ptr, size);
    if (UNLIKELY(grown == NULL))
    {
        free(ptr);
        GSL_ERROR_NULL("not enough memory to resize array", GSL_ENOMEM);
    }
    return grown;
}


/* Subdivide the final last_n pixels of an adaptively refined sky map.
 * Each of the last last_n pixels in the array is replaced by its four
 * NUNIQ children (4 * uniq + j, j = 0..3), appended at the end of the
 * reallocated array; the children's values are left uninitialized.  On
 * success *len is updated and the (possibly moved) array returned; on
 * allocation failure the array has already been freed and NULL is
 * returned. */
static bayestar_pixel *bayestar_pixels_refine(
    bayestar_pixel *pixels, size_t *len, size_t last_n
) {
    assert(last_n <= *len);

    /* New length: adding 4*last_n new pixels, removing last_n old pixels. */
    const size_t new_len = *len + 3 * last_n;
    const size_t new_size = new_len * sizeof(bayestar_pixel);

    pixels = realloc_or_free(pixels, new_size);
    if (LIKELY(pixels))
    {
        /* Expand parents starting from the last one, writing children
         * into the freshly grown tail of the array. */
        for (size_t i = 0; i < last_n; i ++)
        {
            const int64_t uniq = 4 * pixels[*len - i - 1].uniq;
            for (unsigned char j = 0; j < 4; j ++)
                pixels[new_len - (4 * i + j) - 1].uniq = j + uniq;
        }
        *len = new_len;
    }
    return pixels;
}


static bayestar_pixel *bayestar_pixels_alloc(size_t *len, unsigned char order)
{
    const int64_t nside = (int64_t)1 << order;
    const int64_t npix = nside2npix64(nside);
    const size_t size = npix * sizeof(bayestar_pixel);

    bayestar_pixel *pixels = malloc(size);
    if (UNLIKELY(!pixels))
        GSL_ERROR_NULL("not enough memory to allocate sky map", GSL_ENOMEM);

    *len = npix;
    for (long long ipix = 0; ipix < npix; ipix ++)
        pixels[ipix].uniq = nest2uniq64(order, ipix);
    return pixels;
}


static void logsumexp(const double *accum, double log_weight, double *result, unsigned long ni, unsigned long nj)
{
    double max_accum[nj];
    for (unsigned long j = 0; j < nj; j ++)
        max_accum[j] = -INFINITY;
    for (unsigned long i = 0; i < ni; i ++)
        for (unsigned long j = 0; j < nj; j ++)
            if (accum[i * nj + j] > max_accum[j])
                max_accum[j] = accum[i * nj + j];
    double sum_accum[nj];
    for (unsigned long j = 0; j < nj; j ++)
        sum_accum[j] = 0;
    for (unsigned long i = 0; i < ni; i ++)
        for (unsigned long j = 0; j < nj; j ++)
            sum_accum[j] += exp(accum[i * nj + j] - max_accum[j]);
    for (unsigned long j = 0; j < nj; j ++)
        result[j] = log(sum_accum[j]) + max_accum[j] + log_weight;
}


/* Evaluate the marginal log posterior for a single sky-map pixel.
 *
 * integrators: array of nint radial integrators (the caller passes one
 *     per distance-prior power); the corresponding nint marginalized
 *     values are written to value[0..nint-1].
 * uniq: NUNIQ index of the pixel, which fixes the sky direction.
 * The result marginalizes over polarization angle (ntwopsi samples),
 * inclination (nu Gauss-Legendre nodes), arrival time (nsamples shifted
 * samples), and — through the radial integrators — distance.
 * rescale_loglikelihood rescales the SNR-derived terms to limit their
 * numerical dynamic range. */
static void bayestar_sky_map_toa_phoa_snr_pixel(
    log_radial_integrator *integrators[],
    unsigned char nint,
    int64_t uniq,
    double *const value,
    double gmst,
    unsigned int nifos,
    unsigned long nsamples,
    float sample_rate,
    const double *epochs,
    const float (**snrs)[2],
    const float (**responses)[3],
    const double *locations,
    const double *horizons,
    float rescale_loglikelihood
) {
    float complex F[nifos];
    float complex snrs_interp[nsamples][nifos];

    {
        double dt[nifos];
        double theta, phi;
        uniq2ang64(uniq, &theta, &phi);

        /* Look up antenna factors */
        PRAGMA_LOOP_COUNT_NIFOS
        for (unsigned int iifo = 0; iifo < nifos; iifo++)
            F[iifo] = antenna_factor(
                responses[iifo], phi, M_PI_2-theta, gmst) * horizons[iifo];

        toa_errors(dt, theta, phi, gmst, nifos, locations, epochs);

        /* Shift SNR time series by the time delay for this sky position */
        PRAGMA_LOOP_COUNT_NSAMPLES
        for (unsigned long isample = 0; isample < nsamples; isample++)
            PRAGMA_LOOP_COUNT_NIFOS
            for (unsigned int iifo = 0; iifo < nifos; iifo++)
                snrs_interp[isample][iifo] = eval_snr(
                    snrs[iifo], nsamples,
                    isample - dt[iifo] * sample_rate - 0.5 * (nsamples - 1));
    }

    /* For each (polarization, inclination) pair, tabulate the template
     * norm term p and, per time sample, the cross term b, together with
     * their logarithms, for use by the radial integrators. */
    float p[ntwopsi][nu], log_p[ntwopsi][nu];
    float b[ntwopsi][nu][nsamples], log_b[ntwopsi][nu][nsamples];
    for (unsigned int itwopsi = 0; itwopsi < ntwopsi; itwopsi++)
    {
        const float twopsi = (2 * M_PI / ntwopsi) * itwopsi;
        const float complex exp_i_twopsi = exp_i(twopsi);

        for (unsigned int iu = 0; iu < nu; iu++)
        {
            const float u = u_points_weights[iu][0];
            const float u2 = u * u;
            float complex z_times_r[nifos];
            float p2 = 0;

            PRAGMA_LOOP_COUNT_NIFOS
            for (unsigned int iifo = 0; iifo < nifos; iifo ++)
            {
                p2 += cabs2(
                    z_times_r[iifo] = bayestar_signal_amplitude_model(
                        F[iifo], exp_i_twopsi, u, u2));
            }
            p2 *= 0.5f;
            p2 *= rescale_loglikelihood * rescale_loglikelihood;
            log_p[itwopsi][iu] = logf(p[itwopsi][iu] = sqrtf(p2));

            PRAGMA_LOOP_COUNT_NSAMPLES
            for (unsigned long isample = 0; isample < nsamples; isample++)
            {
                float complex I0arg_complex_times_r = 0;

                PRAGMA_LOOP_COUNT_NIFOS
                for (unsigned int iifo = 0; iifo < nifos; iifo ++)
                    I0arg_complex_times_r += conjf(z_times_r[iifo]) * snrs_interp[isample][iifo];
                log_b[itwopsi][iu][isample] = logf(b[itwopsi][iu][isample] = cabsf(I0arg_complex_times_r) * rescale_loglikelihood * rescale_loglikelihood);
            }
        }
    }

    /* Evaluate the tabulated radial integral for every combination of
     * integrator, polarization, inclination, and time sample, adding the
     * log quadrature weight of the inclination node. */
    double accum[nint][ntwopsi][nu][nsamples];
    PRAGMA_LOOP_COUNT_NINT
    for (unsigned int iint = 0; iint < nint; iint ++)
        for (unsigned int itwopsi = 0; itwopsi < ntwopsi; itwopsi++)
            for (unsigned int iu = 0; iu < nu; iu++)
                PRAGMA_LOOP_COUNT_NSAMPLES
                for (unsigned long isample = 0; isample < nsamples; isample++)
                    accum[iint][itwopsi][iu][isample] = u_points_weights[iu][1] + log_radial_integrator_eval(integrators[iint], p[itwopsi][iu], b[itwopsi][iu][isample], log_p[itwopsi][iu], log_b[itwopsi][iu][isample]);

    /* Marginalize with a log-sum-exp reduction over everything except
     * the integrator index. */
    PRAGMA_LOOP_COUNT_NINT
    for (unsigned int iint = 0; iint < nint; iint ++)
    {
        double max_accum = -INFINITY;

        for (unsigned int itwopsi = 0; itwopsi < ntwopsi; itwopsi++)
            for (unsigned int iu = 0; iu < nu; iu++)
                PRAGMA_LOOP_COUNT_NSAMPLES
                for (unsigned long isample = 0; isample < nsamples; isample++)
                    if (accum[iint][itwopsi][iu][isample] > max_accum)
                        max_accum = accum[iint][itwopsi][iu][isample];

        double accum1 = 0;

        for (unsigned int itwopsi = 0; itwopsi < ntwopsi; itwopsi++)
            for (unsigned int iu = 0; iu < nu; iu++)
                PRAGMA_LOOP_COUNT_NSAMPLES
                for (unsigned long isample = 0; isample < nsamples; isample++)
                    accum1 += exp(accum[iint][itwopsi][iu][isample] - max_accum);

        value[iint] = log(accum1) + max_accum;
    }
}


/* Guard ensuring the one-time initialization below runs exactly once. */
static pthread_once_t bayestar_init_once = PTHREAD_ONCE_INIT;

/* One-time setup: precompute the cosmology and inclination-quadrature
 * lookup tables, and create ITT profiling handles when built with
 * instrumentation support. */
static void bayestar_init_func(void)
{
    dVC_dVL_init();
    u_points_weights_init();

#ifdef WITH_ITTNOTIFY
    itt_domain = __itt_domain_create("ligo.skymap.bayestar");
    itt_task_lookup_table = __itt_string_handle_create("generating lookup table");
    itt_task_initial_step = __itt_string_handle_create("initial resolution step");
    itt_task_refinement_step = __itt_string_handle_create("resolution refinement step");
    itt_task_final_step = __itt_string_handle_create("final resolution step");
    itt_task_log_posterior = __itt_string_handle_create("log likelihood");
#endif
}

/* Thread-safe entry point: run bayestar_init_func at most once. */
static void bayestar_init(void)
{
    int ret = pthread_once(&bayestar_init_once, bayestar_init_func);
    assert(ret == 0);
    (void)ret; /* Silence unused variable warning */
}


/* Compute an adaptively refined, multi-resolution (NUNIQ) probability sky
 * map for a compact binary candidate from its per-detector SNR series.
 *
 * On success, returns an array of pixels sorted by ascending NUNIQ index,
 * stores its length in *out_len, and fills in the log Bayes factors
 * *out_log_bci (coherent vs. incoherent) and *out_log_bsn (signal vs.
 * noise).  Each pixel's value[] holds {probability, conditional distance
 * mean, conditional distance standard deviation}.  Returns NULL on error
 * or interruption.
 *
 * Fix: the GSL error message below previously read "only for for". */
bayestar_pixel *bayestar_sky_map_toa_phoa_snr(
    size_t *out_len,                /* Number of returned pixels */
    double *out_log_bci,            /* log Bayes factor: coherent vs. incoherent */
    double *out_log_bsn,            /* log Bayes factor: signal vs. noise */
    /* Prior */
    double min_distance,            /* Minimum distance */
    double max_distance,            /* Maximum distance */
    int prior_distance_power,       /* Power of distance in prior */
    int cosmology,                  /* Set to nonzero to include comoving volume correction */
    /* Data */
    double gmst,                    /* GMST (rad) */
    unsigned int nifos,             /* Number of detectors */
    unsigned long nsamples,         /* Length of SNR series */
    float sample_rate,              /* Sample rate in seconds */
    const double *epochs,           /* Timestamps of SNR time series */
    const float (**snrs)[2],        /* SNR amplitude and phase arrays */
    const float (**responses)[3],   /* Detector responses */
    const double **locations,       /* Barycentered Cartesian geographic detector positions (light seconds) */
    const double *horizons,         /* SNR=1 horizon distances for each detector */
    float rescale_loglikelihood                     /* SNR rescale_loglikelihood factor */
) {
    /* Initialize precalculated tables. */
    bayestar_init();

    if (cosmology && prior_distance_power != 2)
    {
        GSL_ERROR_NULL(
            "BAYESTAR supports cosmological priors only for prior_distance_power=2",
            GSL_EINVAL);
    }

    /* Build three radial integrators, one for each successive power of
     * distance (the posterior itself, then the first and second distance
     * moments). */
    log_radial_integrator *integrators[] = {NULL, NULL, NULL};
    ITT_TASK_BEGIN(itt_domain, itt_task_lookup_table);
    {
        /* Amplitude scale for the lookup tables: the root of half the sum
         * of squared horizon distances, rescaled like the likelihood. */
        double pmax = 0;

        PRAGMA_LOOP_COUNT_NIFOS
        for (unsigned int iifo = 0; iifo < nifos; iifo ++)
        {
            pmax += gsl_pow_2(horizons[iifo]);
        }
        pmax = sqrt(0.5 * pmax);
        pmax *= rescale_loglikelihood;

        #pragma omp parallel for
        for (unsigned char k = 0; k < 3; k ++)
        {
            integrators[k] = log_radial_integrator_init(
                min_distance, max_distance, prior_distance_power + k, cosmology,
                pmax, default_log_radial_integrator_size);
        }
    }
    ITT_TASK_END(itt_domain);

    /* Bail out, releasing whatever was built, if any integrator failed. */
    for (unsigned char k = 0; k < 3; k ++)
    {
        if (!integrators[k])
        {
            for (unsigned char kk = 0; kk < k; kk ++)
                log_radial_integrator_free(integrators[kk]);
            return NULL;
        }
    }

    /* Start with a uniform all-sky map at order0 (nside=16). */
    static const unsigned char order0 = 4;
    size_t len;
    bayestar_pixel *pixels = bayestar_pixels_alloc(&len, order0);
    if (!pixels)
    {
        for (unsigned char k = 0; k < 3; k ++)
            log_radial_integrator_free(integrators[k]);
        return NULL;
    }
    const unsigned long npix0 = len;

    OMP_BEGIN_INTERRUPTIBLE

    /* Logarithm of the normalization factor for the prior. */
    const double log_norm = -log(
            2                           /* inclination */
            * (2 * M_PI)                /* coalescence phase? */
            * (4 * M_PI) * ntwopsi      /* polarization angle */
            * nsamples                  /* time samples */
        ) - log_radial_integrator_eval( /* distance */
            integrators[0], 0, 0, -INFINITY, -INFINITY
        );

   /* At the lowest order, compute both the coherent probability map and the
    * incoherent evidence. */
    double log_evidence_coherent, log_evidence_incoherent[nifos];
    {
        double accum[npix0][nifos];

        ITT_TASK_BEGIN(itt_domain, itt_task_initial_step);
        #pragma omp parallel for schedule(guided)
        for (unsigned long i = 0; i < npix0; i ++)
        {
            if (OMP_WAS_INTERRUPTED)
                OMP_EXIT_LOOP_EARLY;

            /* Coherent: all detectors together. */
            bayestar_sky_map_toa_phoa_snr_pixel(integrators, 1, pixels[i].uniq,
                pixels[i].value, gmst, nifos, nsamples, sample_rate, epochs,
                snrs, responses, locations, horizons, rescale_loglikelihood);

            /* Incoherent: each detector on its own. */
            PRAGMA_LOOP_COUNT_NIFOS
            for (unsigned int iifo = 0; iifo < nifos; iifo ++)
            {
                bayestar_sky_map_toa_phoa_snr_pixel(integrators, 1,
                    pixels[i].uniq, &accum[i][iifo], gmst, 1, nsamples,
                    sample_rate, &epochs[iifo], &snrs[iifo], &responses[iifo],
                    &locations[iifo], &horizons[iifo], rescale_loglikelihood);
            }
        }
        ITT_TASK_END(itt_domain);

        if (OMP_WAS_INTERRUPTED)
            goto done;

        /* All order0 pixels share the same area, so one weight suffices. */
        const double log_weight = log_norm + log(uniq2pixarea64(pixels[0].uniq));

        logsumexp(*accum, log_weight, log_evidence_incoherent, npix0, nifos);
    }

    /* Sort pixels by ascending posterior probability. */
    bayestar_pixels_sort_prob(pixels, len);

    /* Adaptively refine until order=11 (nside=2048). */
    for (unsigned char level = order0; level < 11; level ++)
    {
        /* Adaptively refine the pixels that contain the most probability. */
        pixels = bayestar_pixels_refine(pixels, &len, npix0 / 4);
        if (!pixels)
            goto done;

        ITT_TASK_BEGIN(itt_domain, itt_task_refinement_step);
        #pragma omp parallel for schedule(guided)
        for (unsigned long i = len - npix0; i < len; i ++)
        {
            if (OMP_WAS_INTERRUPTED)
                OMP_EXIT_LOOP_EARLY;

            bayestar_sky_map_toa_phoa_snr_pixel(integrators, 1, pixels[i].uniq,
                pixels[i].value, gmst, nifos, nsamples, sample_rate, epochs,
                snrs, responses, locations, horizons, rescale_loglikelihood);
        }
        ITT_TASK_END(itt_domain);

        if (OMP_WAS_INTERRUPTED)
            goto done;

        /* Sort pixels by ascending posterior probability. */
        bayestar_pixels_sort_prob(pixels, len);
    }

    /* Evaluate distance layers. */
    ITT_TASK_BEGIN(itt_domain, itt_task_final_step);
    #pragma omp parallel for schedule(guided)
    for (unsigned long i = 0; i < len; i ++)
    {
        if (OMP_WAS_INTERRUPTED)
            OMP_EXIT_LOOP_EARLY;

        /* First and second distance moments go in value[1] and value[2]. */
        bayestar_sky_map_toa_phoa_snr_pixel(&integrators[1], 2, pixels[i].uniq,
            &pixels[i].value[1], gmst, nifos, nsamples, sample_rate, epochs,
            snrs, responses, locations, horizons, rescale_loglikelihood);
    }
    ITT_TASK_END(itt_domain);

done:
    for (unsigned char k = 0; k < 3; k ++)
        log_radial_integrator_free(integrators[k]);

    if (OMP_WAS_INTERRUPTED)
    {
        free(pixels);
        pixels = NULL;
    }

    if (pixels)
    {
        /* Rescale so that log(max) = 0. */
        const double max_logp = pixels[len - 1].value[0];
        for (ssize_t i = (ssize_t)len - 1; i >= 0; i --)
            for (unsigned char k = 0; k < 3; k ++)
                pixels[i].value[k] -= max_logp;

        /* Determine normalization of map. */
        double norm = 0;
        for (ssize_t i = (ssize_t)len - 1; i >= 0; i --)
        {
            const double dA = uniq2pixarea64(pixels[i].uniq);
            const double dP = gsl_sf_exp_mult(pixels[i].value[0], dA);
            if (dP <= 0)
                break; /* We have reached underflow. */
            norm += dP;
        }
        log_evidence_coherent = log(norm) + max_logp + log_norm;
        norm = 1 / norm;

        /* Rescale, normalize, and prepare output. */
        for (ssize_t i = (ssize_t)len - 1; i >= 0; i --)
        {
            const double prob = gsl_sf_exp_mult(pixels[i].value[0], norm);
            double rmean = exp(pixels[i].value[1] - pixels[i].value[0]);
            double rstd = exp(pixels[i].value[2] - pixels[i].value[0]) - gsl_pow_2(rmean);
            if (rstd >= 0)
            {
                rstd = sqrt(rstd);
            } else {
                /* Negative variance from roundoff: flag with sentinels. */
                rmean = INFINITY;
                rstd = 1;
            }
            pixels[i].value[0] = prob;
            pixels[i].value[1] = rmean;
            pixels[i].value[2] = rstd;
        }

        /* Sort pixels by ascending NUNIQ index. */
        bayestar_pixels_sort_uniq(pixels, len);

        /* Calculate log Bayes factor. */
        *out_log_bci = *out_log_bsn = log_evidence_coherent;

        PRAGMA_LOOP_COUNT_NIFOS
        for (unsigned int i = 0; i < nifos; i ++)
            *out_log_bci -= log_evidence_incoherent[i];

        /* Done! */
        *out_len = len;
    }

    OMP_END_INTERRUPTIBLE

    return pixels;
}


/* Evaluate the log posterior at a single point of parameter space (sky
 * position, distance, inclination, polarization, barycentered arrival
 * time), marginalized over coalescence phase only.
 *
 * Returns -INFINITY when the distance lies outside [min_distance,
 * max_distance]; otherwise the log posterior, including the distance
 * prior r^prior_distance_power and, if cosmology is nonzero, the
 * comoving-volume correction.
 *
 * Fix: the out-of-range early return used to happen after
 * ITT_TASK_BEGIN but skipped ITT_TASK_END, leaving the profiling task
 * unbalanced; the bounds check now precedes the task. */
double bayestar_log_posterior_toa_phoa_snr(
    /* Parameters */
    double ra,                      /* Right ascension (rad) */
    double sin_dec,                 /* Sin(declination) */
    double distance,                /* Distance */
    double u,                       /* Cos(inclination) */
    double twopsi,                  /* Twice polarization angle (rad) */
    double t,                       /* Barycentered arrival time (s) */
    /* Prior */
    double min_distance,            /* Minimum distance */
    double max_distance,            /* Maximum distance */
    int prior_distance_power,       /* Power of distance in prior */
    int cosmology,                  /* Set to nonzero to include comoving volume correction */
    /* Data */
    double gmst,                    /* GMST (rad) */
    unsigned int nifos,             /* Number of detectors */
    unsigned long nsamples,         /* Lengths of SNR series */
    double sample_rate,             /* Sample rate in seconds */
    const double *epochs,           /* Timestamps of SNR time series */
    const float (**snrs)[2],        /* SNR amplitude and phase arrays */
    const float (**responses)[3],   /* Detector responses */
    const double **locations,       /* Barycentered Cartesian geographic detector positions (light seconds) */
    const double *horizons,         /* SNR=1 horizon distances for each detector */
    float rescale_loglikelihood                     /* SNR rescale_loglikelihood factor */
) {
    bayestar_init();

    /* Reject out-of-range distances before opening the ITT task so that
     * the early return cannot leave ITT_TASK_BEGIN unmatched. */
    if (distance < min_distance || distance > max_distance)
        return -INFINITY;

    ITT_TASK_BEGIN(itt_domain, itt_task_log_posterior);

    const double dec = asin(sin_dec);
    const double u2 = gsl_pow_2(u);
    const double complex exp_i_twopsi = exp_i(twopsi);
    const double one_by_r = 1 / distance;

    /* Compute time of arrival errors */
    double dt[nifos];
    toa_errors(dt, M_PI_2 - dec, ra, gmst, nifos, locations, epochs);

    double complex i0arg_complex_times_r = 0;
    double A = 0;

    /* Loop over detectors */
    PRAGMA_LOOP_COUNT_NIFOS
    for (unsigned int iifo = 0; iifo < nifos; iifo++)
    {
        const double complex F = antenna_factor(
            responses[iifo], ra, dec, gmst) * horizons[iifo];

        const double complex z_times_r =
             bayestar_signal_amplitude_model(F, exp_i_twopsi, u, u2);

        /* Accumulate the matched-filter cross term and the template norm
         * (both still carrying the 1/distance scaling). */
        i0arg_complex_times_r += conj(z_times_r)
            * eval_snr(snrs[iifo], nsamples, (t - dt[iifo]) * sample_rate - 0.5 * (nsamples - 1));
        A += cabs2(z_times_r);
    }
    A *= -0.5;

    double i0arg_times_r = cabs(i0arg_complex_times_r);

    /* Rescale the likelihood terms to control their dynamic range. */
    A *= gsl_pow_2(rescale_loglikelihood);
    i0arg_times_r *= gsl_pow_2(rescale_loglikelihood);

    /* Phase-marginalized log likelihood (scaled Bessel I0) plus the
     * distance prior term. */
    double result = (A * one_by_r + i0arg_times_r) * one_by_r
        + log(gsl_sf_bessel_I0_scaled(i0arg_times_r * one_by_r)
                * gsl_pow_int(distance, prior_distance_power));

    if (cosmology)
        result += log_dVC_dVL(distance);

    ITT_TASK_END(itt_domain);

    return result;
}


/*
 * Unit tests
 */


/* Check cabs2 against the square of cabsf. */
static void test_cabs2(float complex z)
{
    const float expected = cabsf(z) * cabsf(z);
    const float result = cabs2(z);
    gsl_test_abs(result, expected, 2 * GSL_FLT_EPSILON,
        "testing cabs2(%g + %g j)", crealf(z), cimagf(z));
}


/* Spot-check the Catmull-Rom interpolant against inputs whose value is
 * known in closed form. */
static void test_catrom(void)
{
    /* All-zero control points must interpolate to zero everywhere. */
    for (float t = 0; t <= 1; t += 0.01f)
    {
        const float actual = catrom(0, 0, 0, 0, t);
        gsl_test_abs(actual, 0, 0,
            "testing Catmull-rom interpolant for zero input");
    }

    /* Constant unit control points must reproduce the constant. */
    for (float t = 0; t <= 1; t += 0.01f)
    {
        const float actual = catrom(1, 1, 1, 1, t);
        gsl_test_abs(actual, 1, 0,
            "testing Catmull-rom interpolant for unit input");
    }

    /* Control points sampled from t^2 must be reproduced exactly. */
    for (float t = 0; t <= 1; t += 0.01f)
    {
        const float actual = catrom(1, 0, 1, 4, t);
        const float wanted = gsl_pow_2(t);
        gsl_test_abs(actual, wanted, 0,
            "testing Catmull-rom interpolant for quadratic real input");
    }
}


/* Exercise eval_snr against the analytic series x(t) = t^2 * exp(i t). */
static void test_eval_snr(void)
{
    static const size_t nsamples = 64;
    float x[nsamples][2];

    /* Tabulate amplitude and phase of x(t) = t^2 * exp(i * t). */
    for (size_t i = 0; i < nsamples; i ++)
    {
        x[i][0] = gsl_pow_2(i);
        x[i][1] = i;
    }

    for (float t = 0; t <= nsamples; t += 0.1)
    {
        /* Outside the valid interpolation window, the series reads as 0. */
        const int inside = t > 1 && t < nsamples - 2;
        const float complex expected = inside ? (gsl_pow_2(t) * exp_i(t)) : 0;
        const float complex result = eval_snr(x, nsamples, t);

        gsl_test_abs(cabsf(result), cabsf(expected), 1e4 * GSL_FLT_EPSILON,
            "testing abs of eval_snr(%g) for x(t) = t^2 * exp(i * t)", t);
        gsl_test_abs(cargf(result), cargf(expected), 1e4 * GSL_FLT_EPSILON,
            "testing arg of eval_snr(%g) for x(t) = t^2 * exp(i * t)", t);
    }
}


/* Check the interpolated log radial integral against a reference value.
 *
 * expected: reference value of the log integral
 * tol:      relative tolerance for the comparison
 * r1, r2:   integration limits
 * p2, b:    quadratic and linear coefficients of the integrand
 * k:        power of r in the distance prior */
static void test_log_radial_integral(
    double expected, double tol, double r1, double r2, double p2, double b, int k)
{
    const double p = sqrt(p2);
    log_radial_integrator *integrator = log_radial_integrator_init(
        r1, r2, k, 0, p + 0.5, default_log_radial_integrator_size);

    gsl_test(!integrator, "testing that integrator object is non-NULL");
    if (integrator)
    {
        const double result = log_radial_integrator_eval(integrator, p, b, log(p), log(b));

        gsl_test_rel(
            result, expected, tol,
            "testing toa_phoa_snr_log_radial_integral("
            "r1=%g, r2=%g, p2=%g, b=%g, k=%d)", r1, r2, p2, b, k);
        /* Use the matching destructor: a plain free() would leak the
         * three interpolants owned by the integrator. */
        log_radial_integrator_free(integrator);
    }
}


/* Round-trip test: convert distance moments (mean, std) to the ansatz
 * parameters (mu, sigma, norm) and back.  For valid moments the round
 * trip must be faithful; for out-of-bounds moments the conversion must
 * produce the sentinel values (mu=+inf, sigma=1, norm=0). */
static void test_distance_moments_to_parameters_round_trip(double mean, double std)
{
    /* Smallest mean/std ratio for which the ansatz is well defined;
     * the parameterization is singular at mean/std = sqrt(3). */
    static const double min_mean_std = M_SQRT3 + 1e-2;
    const double mean_std = mean / std;
    double mu, sigma, norm, mean2, std2, norm2;

    bayestar_distance_moments_to_parameters(
        mean, std, &mu, &sigma, &norm);
    bayestar_distance_parameters_to_moments(
        mu, sigma, &mean2, &std2, &norm2);

    if (gsl_finite(mean_std) && mean_std >= min_mean_std)
    {
        /* Precision degrades as we approach the singularity at
         * mean/std=sqrt(3). Relax the tolerance of the test near there. */
        const double rtol = mean_std >= min_mean_std + 0.1 ? 1e-9 : 6e-5;
        gsl_test_rel(norm2, norm, rtol,
            "testing round-trip conversion of normalization for mean=%g, std=%g",
            mean, std);
        gsl_test_rel(mean2, mean, rtol,
            "testing round-trip conversion of mean for mean=%g, std=%g",
            mean, std);
        gsl_test_rel(std2, std, rtol,
            "testing round-trip conversion of std for mean=%g, std=%g",
            mean, std);
    } else {
        gsl_test_int(gsl_isinf(mu), 1,
            "testing that out-of-bounds value gives mu=+inf for mean=%g, std=%g",
            mean, std);
        gsl_test_abs(sigma, 1, 0,
            "testing that out-of-bounds value gives sigma=1 for mean=%g, std=%g",
            mean, std);
        gsl_test_abs(norm, 0, 0,
            "testing that out-of-bounds value gives norm=0 for mean=%g, std=%g",
            mean, std);
        gsl_test_int(gsl_isinf(mean2), 1,
            "testing that out-of-bounds value gives mean=+inf for mean=%g, std=%g",
            mean, std);
        gsl_test_abs(std2, 1, 0,
            "testing that out-of-bounds value gives std=1 for mean=%g, std=%g",
            mean, std);
        gsl_test_abs(norm2, 0, 0,
            "testing that out-of-bounds value gives norm=0 for mean=%g, std=%g",
            mean, std);
    }
}


/* Check the round trip between (order, nest) pixel indices and unique
 * ("NUNIQ") pixel indices at 64-bit precision.
 *
 * The int64_t arguments are explicitly cast to unsigned long long in the
 * gsl_test format arguments: %llu requires exactly unsigned long long,
 * and on LP64 platforms int64_t is long, so passing it uncast is
 * undefined behavior. */
static void test_nest2uniq64(uint8_t order, int64_t nest, int64_t uniq)
{
    const int64_t uniq_result = nest2uniq64(order, nest);
    gsl_test(!(uniq_result == uniq),
        "expected nest2uniq64(%u, %llu) = %llu, got %llu",
        (unsigned) order, (unsigned long long) nest,
        (unsigned long long) uniq, (unsigned long long) uniq_result);

    int64_t nest_result;
    const uint8_t order_result = uniq2nest64(uniq, &nest_result);
    gsl_test(!(nest_result == nest && order_result == order),
        "expected uniq2nest64(%llu) = (%u, %llu), got (%u, %llu)",
        (unsigned long long) uniq, (unsigned) order,
        (unsigned long long) nest,
        (unsigned) order_result, (unsigned long long) nest_result);
}


/* Verify the precomputed comoving-volume prior against tabulated values:
 * for every tabulated luminosity distance, exp(log_dVC_dVL(DL)) must
 * match the expected dVC/dVL to within a relative tolerance. */
static void test_cosmology(void)
{
    static const int num_samples =
        sizeof(dVC_dVL_test_x) / sizeof(*dVC_dVL_test_x);

    for (int sample = 0; sample < num_samples; sample ++)
    {
        const double DL = dVC_dVL_test_x[sample];
        const double expected = dVC_dVL_test_y[sample];
        const double result = exp(log_dVC_dVL(DL));
        gsl_test_rel(result, expected, 2e-3,
            "testing cosmological prior for DL=%g", DL);
    }
}


/* Top-level test driver for the BAYESTAR C library.
 * Runs all unit tests and returns the gsl_test summary status. */
int bayestar_test(void)
{
    /* Initialize precalculated tables. */
    bayestar_init();

    /* Complex |z|^2 helper over a grid of complex values. */
    for (double re = -1; re < 1; re += 0.1)
        for (double im = -1; im < 1; im += 0.1)
            test_cabs2(re + im * 1.0j);

    test_catrom();
    test_eval_snr();

    /* Tests of radial integrand with p2=0, b=0. */
    test_log_radial_integral(0, 0, 0, 1, 0, 0, 0);
    test_log_radial_integral(0, 0, exp(1), exp(2), 0, 0, -1);
    test_log_radial_integral(log(63), 0, 3, 6, 0, 0, 2);
    /* Tests of radial integrand with p2>0, b=0 (from Mathematica). */
    test_log_radial_integral(-0.480238, 1e-3, 1, 2, 1, 0, 0);
    test_log_radial_integral(0.432919, 1e-3, 1, 2, 1, 0, 2);
    test_log_radial_integral(-2.76076, 1e-3, 0, 1, 1, 0, 2);
    test_log_radial_integral(61.07118, 1e-3, 0, 1e9, 1, 0, 2);
    test_log_radial_integral(-112.23053, 5e-2, 0, 0.1, 1, 0, 2);
    /* Note: this test underflows, so we test that the log is -inf. */
    /* test_log_radial_integral(-1.00004e6, 1e-8, 0, 1e-3, 1, 0, 2); */
    test_log_radial_integral(-INFINITY, 1e-3, 0, 1e-3, 1, 0, 2);

    /* Tests of radial integrand with p2>0, b>0 with ML peak inside
     * of integration limits (true values from Mathematica NIntegrate).
     * (For p=1, b=1 the peak is at r0 = 2 p^2 / b = 2, within [r1, 4].) */
    test_log_radial_integral(2.94548, 1e-4, 0, 4, 1, 1, 2);
    test_log_radial_integral(2.94545, 1e-4, 0.5, 4, 1, 1, 2);
    test_log_radial_integral(2.94085, 1e-4, 1, 4, 1, 1, 2);
    /* Tests of radial integrand with p2>0, b>0 with ML peak outside
     * of integration limits (true values from Mathematica NIntegrate). */
    test_log_radial_integral(-2.43264, 1e-5, 0, 1, 1, 1, 2);
    test_log_radial_integral(-2.43808, 1e-5, 0.5, 1, 1, 1, 2);
    test_log_radial_integral(-0.707038, 1e-5, 1, 1.5, 1, 1, 2);

    /* Compare the table-based integrator against direct quadrature over
     * a grid of (p, b). */
    {
        const double r1 = 0.0, r2 = 0.25, pmax = 1.0;
        const int k = 2;
        const double tol = 1e-5;
        log_radial_integrator *integrator = log_radial_integrator_init(
            r1, r2, k, 0, pmax, default_log_radial_integrator_size);

        gsl_test(!integrator, "testing that integrator object is non-NULL");
        if (integrator)
        {
            for (double p = 0.01; p <= pmax; p += 0.01)
            {
                /* NOTE(review): the first iteration has b == 0, so
                 * r0 == inf and log(b) == -inf; presumably the
                 * integrator handles this limit — confirm. */
                for (double b = 0.0; b <= 2 * pmax; b += 0.01)
                {
                    /* x and y only appear in the failure message. */
                    const double r0 = 2 * gsl_pow_2(p) / b;
                    const double x = log(p);
                    const double y = log(r0);
                    const double expected = exp(log_radial_integral(r1, r2, p, b, k, 0));
                    const double result = exp(log_radial_integrator_eval(integrator, p, b, log(p), log(b)) - gsl_pow_2(0.5 * b / p));
                    gsl_test_abs(
                        result, expected, tol, "testing log_radial_integrator_eval("
                        "r1=%g, r2=%g, p=%g, b=%g, k=%d, x=%g, y=%g)", r1, r2, p, b, k, x, y);
                }
            }
            free(integrator);
        }
    }

    /* Round-trip the distance ansatz over a grid of moments (includes
     * out-of-domain points such as std == 0). */
    for (double mean = 0; mean < 100; mean ++)
        for (double std = 0; std < 100; std ++)
            test_distance_moments_to_parameters_round_trip(mean, std);

    /* HEALPix NEST <-> NUNIQ index conversions at several orders. */
    test_nest2uniq64(0, 0, 4);
    test_nest2uniq64(0, 1, 5);
    test_nest2uniq64(0, 2, 6);
    test_nest2uniq64(0, 3, 7);
    test_nest2uniq64(0, 4, 8);
    test_nest2uniq64(0, 5, 9);
    test_nest2uniq64(0, 6, 10);
    test_nest2uniq64(0, 7, 11);
    test_nest2uniq64(0, 8, 12);
    test_nest2uniq64(0, 9, 13);
    test_nest2uniq64(0, 10, 14);
    test_nest2uniq64(0, 11, 15);
    test_nest2uniq64(1, 0, 16);
    test_nest2uniq64(1, 1, 17);
    test_nest2uniq64(1, 2, 18);
    test_nest2uniq64(1, 47, 63);
    test_nest2uniq64(12, 0, 0x4000000ull);
    test_nest2uniq64(12, 1, 0x4000001ull);
    test_nest2uniq64(29, 0, 0x1000000000000000ull);
    test_nest2uniq64(29, 1, 0x1000000000000001ull);
    test_nest2uniq64(29, 0x2FFFFFFFFFFFFFFFull, 0x3FFFFFFFFFFFFFFFull);

    test_cosmology();

    return gsl_test_summary();
}

src/cubic_interp.c

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
/*
 * Copyright (C) 2015-2020  Leo Singer
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */


#include "cubic_interp.h"
#include "branch_prediction.h"
#include "vmath.h"
#include <math.h>
#include <stdalign.h>
#include <stdlib.h>
#include <string.h>

/* Allow contraction of a * b + c to a faster fused multiply-add operation.
 * This pragma is supposedly standard C, but only clang seems to support it.
 * On other compilers, floating point contraction is ON by default at -O3. */
#if defined(__clang__) || defined(__llvm__)
#pragma STDC FP_CONTRACT ON
#endif

#define VCLIP(x, a, b) VMIN(VMAX((x), (a)), (b))
#define VCUBIC(a, t) (t * (t * (t * a[0] + a[1]) + a[2]) + a[3])


/* One-dimensional cubic interpolant on a regularly spaced grid.
 * f is the inverse grid spacing (1/dt); t0 maps an abscissa t to a
 * padded table index via x = f * t + t0; length is the number of rows
 * of a, a flexible array of per-interval cubic coefficients
 * (a[i][0]*x^3 + a[i][1]*x^2 + a[i][2]*x + a[i][3]). */
struct cubic_interp{
    double f, t0, length;
    double a[][4];
};


/* Two-dimensional (bicubic) interpolant.  The two lanes of each v2df
 * hold the s and t components respectively: fx is the inverse grid
 * spacing, x0 the index offset, and xlength the padded table extents.
 * a is a flexible array of 4x4 coefficient blocks, one per grid cell. */
struct bicubic_interp {
    v2df fx, x0, xlength;
    v4df a[][4];
};


/*
 * Compute the coefficients of the interpolating polynomial in the form
 *      a[0] * t^3 + a[1] * t^2 + a[2] * t + a[3]
 * from the four samples z[0..3].  z1 supplies the samples that are
 * subjected to the finiteness checks (in the 1-D case the caller passes
 * z for both arguments).
 */
static void cubic_interp_init_coefficients(
    double *a, const double *z, const double *z1)
{
    /* The constant term is always the left inner sample. */
    a[3] = z[1];

    if (UNLIKELY(!isfinite(z1[1] + z1[2])))
    {
        /* An inner grid point is NaN or infinite: degrade to
         * nearest-neighbor interpolation. */
        a[0] = 0;
        a[1] = 0;
        a[2] = 0;
        return;
    }

    if (UNLIKELY(!isfinite(z1[0] + z1[3])))
    {
        /* An outer grid point is NaN or infinite: degrade to linear
         * interpolation between the two inner samples. */
        a[0] = 0;
        a[1] = 0;
        a[2] = z[2] - z[1];
        return;
    }

    /* All four grid points are finite: use the full cubic. */
    a[0] = 1.5 * (z[1] - z[2]) + 0.5 * (z[3] - z[0]);
    a[1] = z[0] - 2.5 * z[1] + 2 * z[2] - 0.5 * z[3];
    a[2] = 0.5 * (z[2] - z[0]);
}


/* Build a 1-D cubic interpolant for n samples starting at tmin with
 * spacing dt.  Returns NULL on allocation failure; free with
 * cubic_interp_free. */
cubic_interp *cubic_interp_init(
    const double *data, int n, double tmin, double dt)
{
    /* Pad the table by 3 intervals on each side so that evaluation can
     * clamp out-of-range arguments. */
    const int length = n + 6;
    cubic_interp *interp =
        malloc(sizeof(*interp) + length * sizeof(*interp->a));

    if (LIKELY(interp))
    {
        interp->f = 1 / dt;
        interp->t0 = 3 - interp->f * tmin;
        interp->length = length;

        for (int row = 0; row < length; row ++)
        {
            /* Gather a 4-sample stencil, clamping indices to the valid
             * range of data (nearest-edge extension). */
            double samples[4];
            for (int offset = 0; offset < 4; offset ++)
                samples[offset] = data[VCLIP(row + offset - 4, 0, n - 1)];
            cubic_interp_init_coefficients(interp->a[row], samples, samples);
        }
    }
    return interp;
}


/* Deallocate an interpolant created by cubic_interp_init.
 * Safe to call with NULL. */
void cubic_interp_free(cubic_interp *interp)
{
    free(interp);
}


/* Evaluate the cubic interpolant at t.  NaN input propagates to the
 * output; out-of-range arguments are clamped to the table edges. */
double cubic_interp_eval(const cubic_interp *interp, double t)
{
    if (UNLIKELY(isnan(t)))
        return t;

    /* Map t to a fractional table index and clamp to [0, length - 1]. */
    double x = t, xmin = 0.0, xmax = interp->length - 1.0;
    x *= interp->f;
    x += interp->t0;
    x = VCLIP(x, xmin, xmax);

    /* Split into an integer interval index and a fractional offset. */
    double ix = VFLOOR(x);
    x -= ix;

    /* Evaluate that interval's cubic (Horner form via VCUBIC). */
    const double *a = interp->a[(int) ix];
    return VCUBIC(a, x);
}


/* Build a bicubic interpolant for samples on a regular ns-by-nt grid.
 *
 * data is row-major with s as the slow dimension: sample (is, it) is at
 * data[is * nt + it].  smin/tmin and ds/dt give the coordinates of the
 * first sample and the grid spacings.  Returns NULL on allocation
 * failure; free with bicubic_interp_free. */
bicubic_interp *bicubic_interp_init(
    const double *data, int ns, int nt,
    double smin, double tmin, double ds, double dt)
{
    bicubic_interp *interp = NULL;
    /* Pad the table by 3 intervals on each side so that evaluation can
     * clamp out-of-range arguments. */
    const int slength = ns + 6;
    const int tlength = nt + 6;
    interp = aligned_alloc(
        alignof(bicubic_interp),
        sizeof(*interp) + slength * tlength * sizeof(*interp->a));
    if (LIKELY(interp))
    {
        interp->fx[0] = 1 / ds;
        interp->fx[1] = 1 / dt;
        interp->x0[0] = 3 - interp->fx[0] * smin;
        interp->x0[1] = 3 - interp->fx[1] * tmin;
        interp->xlength[0] = slength;
        interp->xlength[1] = tlength;

        for (int is = 0; is < slength; is ++)
        {
            for (int it = 0; it < tlength; it ++)
            {
                double a[4][4], a1[4][4];
                /* First pass: fit a cubic along t for each of the four
                 * neighboring s rows, clamping indices to the valid
                 * data range (nearest-edge extension). */
                for (int js = 0; js < 4; js ++)
                {
                    double z[4];
                    int ks = VCLIP(is + js - 4, 0, ns - 1);
                    for (int jt = 0; jt < 4; jt ++)
                    {
                        int kt = VCLIP(it + jt - 4, 0, nt - 1);
                        /* BUG FIX: the row stride of data is nt, the
                         * number of t samples, not ns; the two only
                         * coincide for square grids. */
                        z[jt] = data[ks * nt + kt];
                    }
                    cubic_interp_init_coefficients(a[js], z, z);
                }
                /* Transpose so the second pass runs along s. */
                for (int js = 0; js < 4; js ++)
                {
                    for (int jt = 0; jt < 4; jt ++)
                    {
                        a1[js][jt] = a[jt][js];
                    }
                }
                /* Second pass: fit cubics along s.  a1[3] holds the
                 * original data values (the constant terms from the
                 * first pass), used for the finiteness checks. */
                for (int js = 0; js < 4; js ++)
                {
                    cubic_interp_init_coefficients(a[js], a1[js], a1[3]);
                }
                /* BUG FIX: the coefficient table is row-major with row
                 * stride tlength; using slength here wrote out of
                 * bounds whenever ns > nt and scrambled cells whenever
                 * ns != nt.  (Behavior for square grids is unchanged.) */
                memcpy(interp->a[is * tlength + it], a, sizeof(a));
            }
        }
    }
    return interp;
}


/* Deallocate an interpolant created by bicubic_interp_init.
 * (Memory from aligned_alloc is correctly released with free.)
 * Safe to call with NULL. */
void bicubic_interp_free(bicubic_interp *interp)
{
    free(interp);
}


/* Evaluate the bicubic interpolant at (s, t).  NaN in either coordinate
 * propagates to the output; out-of-range arguments are clamped to the
 * table edges. */
double bicubic_interp_eval(const bicubic_interp *interp, double s, double t)
{
    if (UNLIKELY(isnan(s) || isnan(t)))
        return s + t;

    /* Map (s, t) to fractional table coordinates and clamp to the
     * padded table bounds. */
    v2df x = {s, t}, xmin = {0.0, 0.0}, xmax = interp->xlength - 1.0;
    x *= interp->fx;
    x += interp->x0;
    x = VCLIP(x, xmin, xmax);

    /* Split into integer cell indices and fractional offsets. */
    v2df ix = VFLOOR(x);
    x -= ix;

    /* BUG FIX: the coefficient table is row-major with row stride equal
     * to the padded t extent, xlength[1]; using xlength[0] here read
     * from the wrong cell whenever ns != nt. */
    const v4df *a = interp->a[(int) (ix[0] * interp->xlength[1] + ix[1])];
    /* Evaluate along t in all four rows at once, then along s. */
    v4df b = VCUBIC(a, x[1]);
    return VCUBIC(b, x[0]);
}

src/cubic_interp_test.c

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
/*
 * Copyright (C) 2015-2017  Leo Singer
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */


#include "cubic_interp.h"
#include <gsl/gsl_test.h>
#include <gsl/gsl_math.h>
#include <assert.h>


/* Unit tests for the 1-D and 2-D cubic interpolants.
 *
 * The cubic sections check exact reproduction of constant, linear, and
 * quadratic data, and the nearest-neighbor/linear fallbacks for
 * non-finite samples; the bicubic sections do the same in two
 * dimensions.  Returns the gsl_test summary status.
 *
 * Fix: in the bicubic sections, assert(interp) now runs immediately
 * after bicubic_interp_init, BEFORE the interpolant is evaluated (the
 * asserts previously sat after the evaluation loops, so an allocation
 * failure would dereference NULL first).  This matches the cubic
 * sections' pattern. */
int cubic_interp_test(void)
{
    {
        static const double data[] = {0, 0, 0, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = -10; t <= 10; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = 0;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for zero input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {1, 1, 1, 1};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = -10; t <= 10; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = 1;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for unit input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {1, 0, 1, 4};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = gsl_pow_2(t);
            gsl_test_abs(result, expected, 10 * GSL_DBL_EPSILON,
                "testing cubic interpolant for quadratic input");
        }
        cubic_interp_free(interp);
    }

    /* Non-finite samples: the interpolant should degrade gracefully. */
    {
        static const double data[] = {
            GSL_POSINF, GSL_POSINF, GSL_POSINF, GSL_POSINF};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = GSL_POSINF;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for +inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            0, GSL_POSINF, GSL_POSINF, GSL_POSINF};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = GSL_POSINF;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for +inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            GSL_POSINF, GSL_POSINF, GSL_POSINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = GSL_POSINF;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for +inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            0, GSL_POSINF, GSL_POSINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = GSL_POSINF;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for +inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            0, 0, GSL_POSINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0.01; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = 0;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for +inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            0, GSL_NEGINF, GSL_POSINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        const double result = cubic_interp_eval(interp, 1);
        cubic_interp_free(interp);
        const double expected = GSL_POSINF;
        gsl_test_abs(result, expected, 0,
            "testing cubic interpolant for +inf input");
    }

    {
        static const double data[] = {
            0, GSL_POSINF, GSL_NEGINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        const double result = cubic_interp_eval(interp, 0);
        cubic_interp_free(interp);
        const double expected = GSL_POSINF;
        gsl_test_abs(result, expected, 0,
            "testing cubic interpolant for +inf input");
    }

    {
        static const double data[] = {
            0, GSL_NEGINF, GSL_NEGINF, GSL_NEGINF};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = GSL_NEGINF;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for -inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            GSL_NEGINF, GSL_NEGINF, GSL_NEGINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = GSL_NEGINF;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for -inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            0, GSL_NEGINF, GSL_NEGINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = GSL_NEGINF;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for -inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            0, 0, GSL_NEGINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0.01; t <= 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = 0;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for -inf input");
        }
        cubic_interp_free(interp);
    }

    {
        static const double data[] = {
            0, GSL_NEGINF, GSL_POSINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        const double result = cubic_interp_eval(interp, 0);
        cubic_interp_free(interp);
        const double expected = GSL_NEGINF;
        gsl_test_abs(result, expected, 0,
            "testing cubic interpolant for -inf input");
    }

    {
        static const double data[] = {
            0, GSL_POSINF, GSL_NEGINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        const double result = cubic_interp_eval(interp, 1);
        cubic_interp_free(interp);
        const double expected = GSL_NEGINF;
        gsl_test_abs(result, expected, 0,
            "testing cubic interpolant for -inf input");
    }

    {
        static const double data[] = {
            0, GSL_NEGINF, GSL_POSINF, 0};
        cubic_interp *interp = cubic_interp_init(data, 4, -1, 1);
        assert(interp);
        for (double t = 0.01; t < 1; t += 0.01)
        {
            const double result = cubic_interp_eval(interp, t);
            const double expected = GSL_NEGINF;
            gsl_test_abs(result, expected, 0,
                "testing cubic interpolant for indeterminate input");
        }
        cubic_interp_free(interp);
    }


    /* Bicubic interpolant: constant surfaces (including non-finite). */
    {
        static const double constants[] = {
            0, 1, GSL_POSINF, GSL_NEGINF, GSL_NAN};
        for (unsigned k = 0; k < sizeof(constants) / sizeof(*constants); k ++)
        {
            double data[4][4];
            for (int i = 0; i < 4; i ++)
                for (int j = 0; j < 4; j ++)
                    data[i][j] = constants[k];
            bicubic_interp *interp = bicubic_interp_init(
                *data, 4, 4, -1, -1, 1, 1);
            assert(interp);
            for (double s = -5; s <= 2; s += 0.1)
            {
                for (double t = -5; t <= 1; t += 0.1)
                {
                    const double result = bicubic_interp_eval(interp, s, t);
                    const double expected = constants[k];
                    gsl_test_abs(result, expected, 0,
                        "testing bicubic interpolant for constant %g input",
                        constants[k]);
                }
            }
            bicubic_interp_free(interp);
        }
    }

    /* Bicubic interpolant: surfaces that are polynomial in s, in t, and
     * in both, for degrees k = 1, 2. */
    for (int k = 1; k < 3; k ++)
    {
        {
            double data[4][4];
            for (int i = 0; i < 4; i ++)
                for (int j = 0; j < 4; j ++)
                    data[i][j] = gsl_pow_int(i - 1, k);
            bicubic_interp *interp = bicubic_interp_init(
                *data, 4, 4, -1, -1, 1, 1);
            assert(interp);
            for (double s = 0; s <= 1; s += 0.1)
            {
                for (double t = 0; t <= 1; t += 0.1)
                {
                    const double result = bicubic_interp_eval(interp, s, t);
                    const double expected = gsl_pow_int(s, k);
                    gsl_test_abs(result, expected, 10 * GSL_DBL_EPSILON,
                        "testing bicubic interpolant for s^%d input", k);
                }
            }
            bicubic_interp_free(interp);
        }

        {
            double data[4][4];
            for (int i = 0; i < 4; i ++)
                for (int j = 0; j < 4; j ++)
                    data[i][j] = gsl_pow_int(j - 1, k);
            bicubic_interp *interp = bicubic_interp_init(
                *data, 4, 4, -1, -1, 1, 1);
            assert(interp);
            for (double s = 0; s <= 1; s += 0.1)
            {
                for (double t = 0; t <= 1; t += 0.1)
                {
                    const double result = bicubic_interp_eval(interp, s, t);
                    const double expected = gsl_pow_int(t, k);
                    gsl_test_abs(result, expected, 10 * GSL_DBL_EPSILON,
                        "testing bicubic interpolant for t^%d input", k);
                }
            }
            bicubic_interp_free(interp);
        }

        {
            double data[4][4];
            for (int i = 0; i < 4; i ++)
                for (int j = 0; j < 4; j ++)
                    data[i][j] = gsl_pow_int(i - 1, k) + gsl_pow_int(j - 1, k);
            bicubic_interp *interp = bicubic_interp_init(
                *data, 4, 4, -1, -1, 1, 1);
            assert(interp);
            for (double s = 0; s <= 1; s += 0.1)
            {
                for (double t = 0; t <= 1; t += 0.1)
                {
                    const double result = bicubic_interp_eval(interp, s, t);
                    const double expected = gsl_pow_int(s, k)
                                          + gsl_pow_int(t, k);
                    gsl_test_abs(result, expected, 10 * GSL_DBL_EPSILON,
                        "testing bicubic interpolant for s^%d + t^%d input",
                        k, k);
                }
            }
            gsl_test(!gsl_isnan(bicubic_interp_eval(interp, 0, GSL_NAN)),
                "testing that bicubic interpolant for nan input returns nan");
            gsl_test(!gsl_isnan(bicubic_interp_eval(interp, GSL_NAN, 0)),
                "testing that bicubic interpolant for nan input returns nan");
            bicubic_interp_free(interp);
        }
    }

    return gsl_test_summary();
}

src/omp_interruptible.h

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
/*
 * Copyright (C) 2017  Leo Singer
 *
 * These preprocessor macros help make long-running Python C extensions,
 * possibly that contain OpenMP parallel for loops, respond gracefully to
 * signals.
 *
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */


/*
 * In a normal C program that does not mess with signals, when the user
 * types Ctrl-C, the process is sent the SIGINT signal. The default SIGINT
 * signal handler terminates the process swiftly, even if the program is in
 * the middle of a CPU-intensive loop, even an OpenMP parallel for loop.
 *
 * It's different for a Python program. Python itself attaches handlers for
 * most (all?) signals. Python's SIGINT handler sets a flag to remind itself to
 * raise a KeyboardInterrupt exception on the main thread before interpreting
 * the next instruction.
 *
 * If Python is in the middle of executing a long-running method in a Python C
 * extension, then the interpreter will remain unresponsive until the method
 * has returned, when it can raise the KeyboardInterrupt. This delay can be
 * very annoying to the user.
 *
 * This header provides a few macros to temporarily change the SIGINT handler
 * and provide a flag that C functions can check periodically to terminate
 * early. Here's a skeleton code sample that includes an OpenMP loop to show
 * how to use the macros. We start with our basic function, foo, which contains
 * an OpenMP parallel for loop:
 *
 *      int foo(int n)
 *      {
 *          int retval = 0;
 *          #pragma omp parallel for
 *          for (int i = 0; i < n; i ++)
 *          {
 *              ... // The actual work occurs here.
 *          }
 *          return retval;
 *      }
 *
 * We add the macros:
 *
 *      #include "omp_interruptible.h"
 *
 *      // Native C function that does the work.
 *      int foo(int n)
 *      {
 *          int retval = 0;
 *          OMP_BEGIN_INTERRUPTIBLE  // Replace SIGINT handler.
 *          #pragma omp parallel for
 *          for (int i = 0; i < n; i ++)
 *          {
 *              // Exit loop early if SIGINT has fired.
 *              // Note: you can replace OMP_EXIT_LOOP_EARLY with a simple
 *              // `break;` statement or check it in the loop conditional,
 *              // if you are not using an OpenMP loop.
 *              if (OMP_WAS_INTERRUPTED)
 *                  OMP_EXIT_LOOP_EARLY
 *              ...  // The actual work occurs here.
 *          }
 *          if (OMP_WAS_INTERRUPTED)
 *              retval = -1;
 *          OMP_END_INTERRUPTIBLE  // Restore SIGINT handler.
 *          return retval;
 *      }
 *
 * Finally, here's the Python C extension:
 *
 *      #include <Python.h>
 *
 *      static PyObject *mod_foo(PyObject *module, PyObject *args)
 *      {
 *          int retval;
 *
 *          // Run the underlying C function, releasing the global interpreter
 *          // lock (GIL) in the mean time so that other Python threads (if
 *          // any) can run.
 *          Py_BEGIN_ALLOW_THREADS
 *          retval = foo(1000);
 *          Py_END_ALLOW_THREADS
 *
 *          // Important: call PyErr_CheckSignals() to give Python a chance to
 *          // raise a KeyboardInterrupt exception, if needed.
 *
 *          // Indicate success or failure of the method to the interpreter.
 *          PyErr_CheckSignals();
 *          if (retval == 0)
 *              Py_RETURN_NONE;
 *          else
 *              return NULL;
 *      }
 *
 *      static PyMethodDef methods[] = {
 *          {"foo", (PyCFunction)mod_foo, METH_NOARGS, "doc string here"},
 *          {NULL, NULL, 0, NULL}
 *      };
 *
 *       static PyModuleDef moduledef = {
 *           PyModuleDef_HEAD_INIT,
 *           "mod", NULL, -1, methods,
 *           NULL, NULL, NULL, NULL
 *       };
 *
 *      PyMODINIT_FUNC PyInit_mod(void)
 *      {
 *          return PyModule_Create(&moduledef);
 *      }
 */


#ifndef OMP_INTERRUPTIBLE_H
#define OMP_INTERRUPTIBLE_H

#include "branch_prediction.h"

#include <signal.h>
#include <stdlib.h>


/* This is a per-thread pointer to an integer flag that we set when we
 * receive SIGINT. This is a pointer, and not the flag itself, because the
 * flag should be shared among all OpenMP threads, and therefore cannot
 * itself be thread-local. */
static __thread int *omp_interruptible_flag_ptr = NULL;


/* Thread-local copy of the SIGINT disposition that was in effect before
 * omp_interruptible_set_handler installed our handler; it is re-installed
 * by omp_interruptible_restore_handler. A NULL sa_handler marks "no saved
 * action". */
static __thread struct sigaction omp_interruptible_old_action = {
    .sa_handler = NULL
};


/* Re-install the SIGINT disposition saved in omp_interruptible_old_action,
 * then reset the thread-local saved action and flag pointer. Called both on
 * normal completion (OMP_END_INTERRUPTIBLE) and from the signal handler
 * itself, so a second SIGINT is handled by the original disposition. */
static void omp_interruptible_restore_handler(int sig)
{
    int ret = sigaction(sig, &omp_interruptible_old_action, NULL);
    (void)ret; /* FIXME: should probably do something with this return value */
    omp_interruptible_old_action = (struct sigaction) {.sa_handler = NULL};
    omp_interruptible_flag_ptr = NULL;
}


/* SIGINT handler: record the interrupt in the shared flag, restore the
 * previous handler, and re-raise the signal so that the original
 * disposition also observes it. Everything called here (sigaction, raise)
 * is async-signal-safe.
 * NOTE(review): omp_interruptible_flag_ptr is __thread; this assumes the
 * signal is delivered to the same thread that called
 * omp_interruptible_set_handler — on any other thread the pointer would be
 * NULL. Confirm delivery is confined to that thread. */
static void omp_interruptible_handler(int sig)
{
    *omp_interruptible_flag_ptr = 1;
    #pragma omp flush  /* publish the flag to all OpenMP threads */
    omp_interruptible_restore_handler(sig);
    raise(sig);
}


/* The sigaction installed for SIGINT by omp_interruptible_set_handler.
 * Only sa_handler is set; sa_mask and sa_flags are zero-initialized. */
static const struct sigaction omp_interruptible_action = {
    .sa_handler = omp_interruptible_handler
};


/* Point the shared-flag pointer at *flag_ptr, clear the flag, and install
 * omp_interruptible_handler for `sig`, saving the previous action in
 * omp_interruptible_old_action. The flag is cleared before the handler is
 * installed so that the handler can never observe a stale value. */
static void omp_interruptible_set_handler(int sig, int *flag_ptr)
{
    omp_interruptible_flag_ptr = flag_ptr;
    *omp_interruptible_flag_ptr = 0;
    int ret = sigaction(
        sig, &omp_interruptible_action, &omp_interruptible_old_action);
    (void)ret; /* FIXME: should probably do something with this return value */
}


/* Open a scope containing a fresh interrupt flag and take over SIGINT.
 * Must be balanced by OMP_END_INTERRUPTIBLE, which closes the scope. */
#define OMP_BEGIN_INTERRUPTIBLE { \
    int omp_was_interrupted; \
    omp_interruptible_set_handler(SIGINT, &omp_was_interrupted);


/* Restore the original SIGINT disposition and close the scope opened by
 * OMP_BEGIN_INTERRUPTIBLE. */
#define OMP_END_INTERRUPTIBLE \
    omp_interruptible_restore_handler(SIGINT); \
}


/* Nonzero iff SIGINT fired since OMP_BEGIN_INTERRUPTIBLE. Wrapped in a
 * branch-prediction hint because interruption is the rare case. */
#define OMP_WAS_INTERRUPTED UNLIKELY(omp_was_interrupted)


/* OpenMP forbids branching out of a worksharing loop, so under OpenMP the
 * best we can do is skip each remaining iteration's work with `continue`;
 * in a serial build a plain `break` exits immediately. */
#if _OPENMP
#define OMP_EXIT_LOOP_EARLY continue;
#else
#define OMP_EXIT_LOOP_EARLY break;
#endif


#endif /* OMP_INTERRUPTIBLE_H */

src/vmath.h

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
/*
 * Copyright (C) 2020  Leo Singer
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef VMATH_H
#define VMATH_H

#ifndef __cplusplus

#include <math.h>


/* Vector types (gcc/clang/icc vector extension to the C language) */

/* 2- and 4-lane vectors of doubles; lanes are accessed with [] and
 * arithmetic operators apply element-wise. */
typedef double v2df __attribute__ ((vector_size (2 * sizeof(double))));
typedef double v4df __attribute__ ((vector_size (4 * sizeof(double))));


/* Vectorized math functions using x86-64 intrinsics if available */

#ifdef __x86_64__
#include <immintrin.h>
#endif

/* Define `static v2df v2df_<func>(v2df a, v2df b)` that applies the scalar
 * function `scalarfunc` lane-by-lane. Portable fallback for targets without
 * the matching SIMD intrinsic. */
#define V2DF_BINARY_OP(func, scalarfunc) \
static v2df v2df_ ## func(v2df a, v2df b) \
{ \
    v2df result; \
    for (int i = 0; i < 2; i ++) \
        result[i] = scalarfunc(a[i], b[i]); \
    return result; \
}

/* Same as V2DF_BINARY_OP, but for a one-argument scalar function. */
#define V2DF_UNARY_OP(func, scalarfunc) \
static v2df v2df_ ## func(v2df a) \
{ \
    v2df result; \
    for (int i = 0; i < 2; i ++) \
        result[i] = scalarfunc(a[i]); \
    return result; \
}

#ifdef __SSE2__
/* Hardware min/max: a single instruction covers both lanes.
 * NOTE(review): _mm_min_pd/_mm_max_pd return the second operand when either
 * lane is NaN, whereas fmin/fmax prefer the non-NaN operand — confirm
 * callers do not depend on fmin/fmax NaN semantics. */
static v2df v2df_min(v2df a, v2df b) { return _mm_min_pd(a, b); }
static v2df v2df_max(v2df a, v2df b) { return _mm_max_pd(a, b); }
#else
/* Portable lane-wise fallbacks built on fmin/fmax. */
V2DF_BINARY_OP(min, fmin)
V2DF_BINARY_OP(max, fmax)
#endif

#ifdef __SSE4_1__
/* Hardware floor for both lanes (SSE4.1). */
static v2df v2df_floor(v2df a) { return _mm_floor_pd(a); }
#else
V2DF_UNARY_OP(floor, floor)
#endif


/* C11 generics for selected math functions */

/* Return the smaller of two ints (scalar int case for the VMIN generic). */
static int int_min(int a, int b)
{
    if (b < a)
        return b;
    return a;
}

/* Return the larger of two ints (scalar int case for the VMAX generic). */
static int int_max(int a, int b)
{
    if (b > a)
        return b;
    return a;
}

/* Type-generic min: dispatches on the static type of the first argument
 * (v2df, int, or double) via C11 _Generic. */
#define VMIN(a, b) _Generic((a), \
    v2df: v2df_min, \
    int: int_min, \
    double: fmin \
)((a), (b))

/* Type-generic max; same dispatch scheme as VMIN. */
#define VMAX(a, b) _Generic((a), \
    v2df: v2df_max, \
    int: int_max, \
    double: fmax \
)((a), (b))

/* Type-generic floor for v2df and double (no int case: floor of an int is
 * the identity). */
#define VFLOOR(a) _Generic((a), \
    v2df: v2df_floor, \
    double: floor \
)(a)


#endif /* __cplusplus */

#endif /* VMATH_H */