pycobertura report

Filename Stmts Miss Cover Missing
ligo/skymap/__init__.py 11 0 100.00%
ligo/skymap/distance.py 99 16 83.84% 470-483, 496, 645
ligo/skymap/healpix_tree.py 126 42 66.67% 95, 100-101, 110, 163-167, 185-212, 218-241, 292-304
ligo/skymap/kde.py 223 44 80.27% 72-77, 82-116, 134, 154-155, 174, 221, 225, 233-234, 272-276, 389, 403-404, 485, 491, 533
ligo/skymap/moc.py 37 0 100.00%
ligo/skymap/bayestar/__init__.py 231 45 80.52% 71, 80-85, 92-108, 120, 168, 221, 226, 241, 251, 287, 306, 382-387, 405-406, 493-508
ligo/skymap/bayestar/ez_emcee.py 36 30 16.67% 25, 107-156
ligo/skymap/bayestar/filter.py 182 10 94.51% 86, 235, 253-258, 265, 404, 442, 446
ligo/skymap/bayestar/interpolation.py 118 7 94.07% 208-214
ligo/skymap/bayestar/ptemcee.py 27 19 29.63% 12-26, 41-49, 54-57
ligo/skymap/coordinates/__init__.py 8 0 100.00%
ligo/skymap/coordinates/detector.py 27 0 100.00%
ligo/skymap/coordinates/eigenframe.py 28 5 82.14% 98-101, 113
ligo/skymap/io/__init__.py 8 0 100.00%
ligo/skymap/io/fits.py 148 18 87.84% 441, 446-447, 456, 493-499, 522-537
ligo/skymap/io/hdf5.py 68 2 97.06% 298-299
ligo/skymap/io/events/__init__.py 8 0 100.00%
ligo/skymap/io/events/base.py 53 10 81.13% 39-40, 46-53, 140, 147
ligo/skymap/io/events/detector_disabled.py 32 0 100.00%
ligo/skymap/io/events/gracedb.py 20 0 100.00%
ligo/skymap/io/events/hdf.py 129 3 97.67% 59, 228-229
ligo/skymap/io/events/ligolw.py 162 14 91.36% 54, 67, 73-74, 129-130, 145-146, 180-181, 191, 212, 220, 243
ligo/skymap/io/events/magic.py 34 0 100.00%
ligo/skymap/io/events/sqlite.py 20 0 100.00%
ligo/skymap/plot/__init__.py 8 0 100.00%
ligo/skymap/plot/allsky.py 195 33 83.08% 173-174, 181-182, 192-201, 260, 287, 301-304, 340, 352-360, 420-422, 488-490, 561, 586
ligo/skymap/plot/angle.py 13 4 69.23% 26-27, 38, 43
ligo/skymap/plot/backdrop.py 40 24 40.00% 41-44, 84-86, 126-146, 184-205, 209-211
ligo/skymap/plot/cmap.py 18 0 100.00%
ligo/skymap/plot/cylon.py 1 0 100.00%
ligo/skymap/plot/itrs_frame_monkeypatch.py 11 0 100.00%
ligo/skymap/plot/marker.py 21 0 100.00%
ligo/skymap/plot/poly.py 59 50 15.25% 32-42, 53-58, 69-157, 165-184
ligo/skymap/plot/pp.py 81 8 90.12% 98-102, 164, 167-168, 283
ligo/skymap/plot/util.py 31 2 93.55% 41, 49
ligo/skymap/postprocess/__init__.py 8 0 100.00%
ligo/skymap/postprocess/contour.py 61 4 93.44% 72-75
ligo/skymap/postprocess/cosmology.py 24 3 87.50% 55-58
ligo/skymap/postprocess/crossmatch.py 149 7 95.30% 294, 356, 379, 415, 439-442
ligo/skymap/postprocess/ellipse.py 54 10 81.48% 317-323, 347, 364, 371
ligo/skymap/postprocess/util.py 36 16 55.56% 55, 81-84, 88-94, 98-101
ligo/skymap/tool/__init__.py 193 48 75.13% 46-49, 63, 76-79, 93-96, 99, 102-109, 111-115, 193-196, 314-320, 389, 403, 425-446
ligo/skymap/tool/bayestar_inject.py 214 119 44.39% 268-459
ligo/skymap/tool/bayestar_localize_coincs.py 66 19 71.21% 113, 119-149, 152, 162, 178-182, 190
ligo/skymap/tool/bayestar_localize_lvalert.py 81 19 76.54% 107, 118, 125-127, 131-133, 151, 164, 168-178, 182-183
ligo/skymap/tool/bayestar_mcmc.py 76 59 22.37% 67, 73-199
ligo/skymap/tool/bayestar_realize_coincs.py 183 5 97.27% 314-315, 353, 387, 392
ligo/skymap/tool/bayestar_sample_model_psd.py 53 0 100.00%
ligo/skymap/tool/ligo_skymap_combine.py 74 4 94.59% 69, 95, 101, 106
ligo/skymap/tool/ligo_skymap_constellations.py 27 18 33.33% 42-61
ligo/skymap/tool/ligo_skymap_contour.py 31 18 41.94% 53-81
ligo/skymap/tool/ligo_skymap_contour_moc.py 25 14 44.00% 47-76
ligo/skymap/tool/ligo_skymap_flatten.py 34 2 94.12% 56-57
ligo/skymap/tool/ligo_skymap_from_samples.py 94 13 86.17% 102-105, 108-109, 127, 133, 140-141, 146, 166-167, 172
ligo/skymap/tool/ligo_skymap_plot.py 85 5 94.12% 144-155, 170-171
ligo/skymap/tool/ligo_skymap_plot_airmass.py 99 76 23.23% 64, 68, 76-218
ligo/skymap/tool/ligo_skymap_plot_observability.py 73 49 32.88% 67, 75-148
ligo/skymap/tool/ligo_skymap_plot_pp_samples.py 59 44 25.42% 29-35, 62-128
ligo/skymap/tool/ligo_skymap_plot_stats.py 126 3 97.62% 71, 136, 138
ligo/skymap/tool/ligo_skymap_plot_volume.py 122 10 91.80% 88, 94, 98-99, 122, 163, 235-236, 242-243
ligo/skymap/tool/ligo_skymap_stats.py 85 9 89.41% 145-146, 151, 168, 172, 174, 190, 192, 224
ligo/skymap/tool/ligo_skymap_unflatten.py 20 11 45.00% 36-48
ligo/skymap/tool/matplotlib.py 69 6 91.30% 35-36, 48-49, 72-73
ligo/skymap/util/__init__.py 8 0 100.00%
ligo/skymap/util/file.py 25 12 52.00% 30-42
ligo/skymap/util/ilwd.py 45 5 88.89% 65-68, 124
ligo/skymap/util/numpy.py 9 0 100.00%
ligo/skymap/util/progress.py 37 2 94.59% 37-38
ligo/skymap/util/sqlite.py 34 1 97.06% 152
ligo/skymap/util/stopwatch.py 47 12 74.47% 38-41, 44, 54, 57-59, 67, 86-88
src/bayestar_distance.c 252 6 97.62% 93-94, 126-129, 455-456
src/bayestar_moc.c 52 1 98.08% 122
src/bayestar_sky_map.c 631 64 89.86% 253, 269, 335, 452-456, 622, 637-638, 677, 878, 899, 924, 931-933, 964-980, 1009-1010, 1020-1021, 1069, 1102, 1129-1130, 1192-1261
src/cubic_interp.c 84 0 100.00%
src/cubic_interp_test.c 187 0 100.00%
src/omp_interruptible.h 13 5 61.54% 157-162
src/vmath.h 5 0 100.00%
TOTAL 5963 1085 81.80%

ligo/skymap/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
# This file is adapted from the Astropy package template, which is licensed
# under a 3-clause BSD style license - see licenses/TEMPLATE_LICENSE.rst

# Packages may add whatever they like to this file, but
# should keep this content at the top.
# ----------------------------------------------------------------------------
from ._astropy_init import *   # noqa
# ----------------------------------------------------------------------------

# Names exported by ``from ligo.skymap import *``: the OpenMP settings
# object ``omp`` defined at the bottom of this module.
__all__ = ('omp',)


class Omp:
    """OpenMP runtime settings.

    Attributes
    ----------
    num_threads : int
        Number of OpenMP threads. Reading this attribute calls
        ``get_num_threads`` and assigning it calls ``set_num_threads``, both
        from the ``.core`` module (which correspond to
        :man:`omp_get_num_threads` and :man:`omp_set_num_threads`).

    """

    @property
    def num_threads(self):
        # Imported lazily so the package can be imported before .core is
        # needed.
        from . import core
        return core.get_num_threads()

    @num_threads.setter
    def num_threads(self, value):
        from . import core
        core.set_num_threads(value)


# Expose a single module-level settings instance, then remove the class name
# so that only ``omp`` (not ``Omp``) is left in the package namespace.
omp = Omp()
del Omp

ligo/skymap/distance.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
548  
549  
550  
551  
552  
553  
554  
555  
556  
557  
558  
559  
560  
561  
562  
563  
564  
565  
566  
567  
568  
569  
570  
571  
572  
573  
574  
575  
576  
577  
578  
579  
580  
581  
582  
583  
584  
585  
586  
587  
588  
589  
590  
591  
592  
593  
594  
595  
596  
597  
598  
599  
600  
601  
602  
603  
604  
605  
606  
607  
608  
609  
610  
611  
612  
613  
614  
615  
616  
617  
618  
619  
620  
621  
622  
623  
624  
625  
626  
627  
628  
629  
630  
631  
632  
633  
634  
635  
636  
637  
638  
639  
640  
641  
642  
643  
644  
645  
646  
647  
648  
649  
650  
651  
652  
653  
654  
655  
656  
657  
658  
659  
660  
661  
662  
663  
664  
665  
666  
667  
668  
669  
670  
671  
672  
673  
674  
675  
676  
677  
678  
679  
680  
#
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Distance distribution functions from [1]_, [2]_, [3]_.

References
----------
.. [1] Singer, Chen, & Holz, 2016. "Going the Distance: Mapping Host Galaxies
   of LIGO and Virgo Sources in Three Dimensions Using Local Cosmography and
   Targeted Follow-up." ApJL, 829, L15.
   :doi:`10.3847/2041-8205/829/1/L15`

.. [2] Singer, Chen, & Holz, 2016. "Supplement: 'Going the Distance: Mapping
   Host Galaxies of LIGO and Virgo Sources in Three Dimensions Using Local
   Cosmography and Targeted Follow-up' (2016, ApJL, 829, L15)." ApJS, 226, 10.
   :doi:`10.3847/0067-0049/226/1/10`

.. [3] https://asd.gsfc.nasa.gov/Leo.Singer/going-the-distance

"""

import astropy_healpix as ah
import numpy as np
import healpy as hp
import scipy.special
from .core import (conditional_pdf, conditional_cdf, conditional_ppf,
                   moments_to_parameters, parameters_to_moments, volume_render,
                   marginal_pdf, marginal_cdf, marginal_ppf)
from .util.numpy import add_newdoc_ufunc, require_contiguous_aligned

# Public names of this module. Each name appears exactly once; the previous
# version listed 'parameters_to_moments' twice and omitted the public
# function 'parameters_to_marginal_moments'.
__all__ = ('conditional_pdf', 'conditional_cdf', 'conditional_ppf',
           'moments_to_parameters', 'parameters_to_moments', 'volume_render',
           'marginal_pdf', 'marginal_cdf', 'marginal_ppf', 'ud_grade',
           'conditional_kde', 'cartesian_kde_to_moments', 'principal_axes',
           'parameters_to_marginal_moments')


# Attach a NumPy-style docstring to the `conditional_pdf` function imported
# from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned` (from .util.numpy).
add_newdoc_ufunc(conditional_pdf, """\
Conditional distance probability density function (ansatz).

Parameters
----------
r : `numpy.ndarray`
    Distance (Mpc)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
pdf : `numpy.ndarray`
    Conditional probability density according to ansatz.

""")
conditional_pdf = require_contiguous_aligned(conditional_pdf)


# Attach a NumPy-style docstring to the `conditional_cdf` function imported
# from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned`. The Returns section documents a cumulative
# probability (it previously mislabeled the output as a pdf/density).
add_newdoc_ufunc(conditional_cdf, """\
Cumulative conditional distribution of distance (ansatz).

Parameters
----------
r : `numpy.ndarray`
    Distance (Mpc)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
cdf : `numpy.ndarray`
    Conditional cumulative probability according to ansatz.

Examples
--------
Test against numerical integral of pdf.

>>> import scipy.integrate
>>> distmu = 10.0
>>> distsigma = 5.0
>>> distnorm = 1.0
>>> r = 8.0
>>> expected, _ = scipy.integrate.quad(
...     conditional_pdf, 0, r,
...     (distmu, distsigma, distnorm))
>>> result = conditional_cdf(
...     r, distmu, distsigma, distnorm)
>>> np.testing.assert_almost_equal(result, expected)

""")
conditional_cdf = require_contiguous_aligned(conditional_cdf)


# Attach a NumPy-style docstring to the `conditional_ppf` function imported
# from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned`. "ppf" is the standard "percent point
# function" (inverse cdf).
add_newdoc_ufunc(conditional_ppf, """\
Percent point function (inverse cdf) of distribution of distance (ansatz).

Parameters
----------
p : `numpy.ndarray`
    The cumulative distribution function
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
r : `numpy.ndarray`
    Distance at which the cdf is equal to `p`.

Examples
--------
Test against numerical estimate.

>>> import scipy.optimize
>>> distmu = 10.0
>>> distsigma = 5.0
>>> distnorm = 1.0
>>> p = 0.16  # "one-sigma" lower limit
>>> expected_r16 = scipy.optimize.brentq(
... lambda r: conditional_cdf(r, distmu, distsigma, distnorm) - p, 0.0, 100.0)
>>> r16 = conditional_ppf(p, distmu, distsigma, distnorm)
>>> np.testing.assert_almost_equal(r16, expected_r16)

""")
conditional_ppf = require_contiguous_aligned(conditional_ppf)


# Attach a NumPy-style docstring to the `moments_to_parameters` function
# imported from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned`.
add_newdoc_ufunc(moments_to_parameters, """\
Convert ansatz moments to parameters.

This function is the inverse of `parameters_to_moments`.

Parameters
----------
distmean : `numpy.ndarray`
    Conditional mean of distance (Mpc)
diststd : `numpy.ndarray`
    Conditional standard deviation of distance (Mpc)

Returns
-------
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

""")
moments_to_parameters = require_contiguous_aligned(moments_to_parameters)


# Attach a NumPy-style docstring (with doctests) to the `parameters_to_moments`
# function imported from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned`.
add_newdoc_ufunc(parameters_to_moments, """\
Convert ansatz parameters to moments.

This function is the inverse of `moments_to_parameters`.

Parameters
----------
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)

Returns
-------
distmean : `numpy.ndarray`
    Conditional mean of distance (Mpc)
diststd : `numpy.ndarray`
    Conditional standard deviation of distance (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Examples
--------
For mu=0, sigma=1, the ansatz is a chi distribution with 3 degrees of
freedom, and the moments have simple expressions.

>>> mean, std, norm = parameters_to_moments(0, 1)
>>> expected_mean = 2 * np.sqrt(2 / np.pi)
>>> expected_std = np.sqrt(3 - expected_mean**2)
>>> expected_norm = 2.0
>>> np.testing.assert_allclose(mean, expected_mean)
>>> np.testing.assert_allclose(std, expected_std)
>>> np.testing.assert_allclose(norm, expected_norm)

Check that the moments scale as expected when we vary sigma.

>>> sigma = np.logspace(-8, 8)
>>> mean, std, norm = parameters_to_moments(0, sigma)
>>> np.testing.assert_allclose(mean, expected_mean * sigma)
>>> np.testing.assert_allclose(std, expected_std * sigma)
>>> np.testing.assert_allclose(norm, expected_norm / sigma**2)

Check some more arbitrary values using numerical quadrature:

>>> import scipy.integrate
>>> sigma = 1.0
>>> for mu in np.linspace(-10, 10):
...     mean, std, norm = parameters_to_moments(mu, sigma)
...     moments = np.empty(3)
...     for k in range(3):
...         moments[k], _ = scipy.integrate.quad(
...             lambda r: r**k * conditional_pdf(r, mu, sigma, 1.0),
...             0, np.inf)
...     expected_norm = 1 / moments[0]
...     expected_mean, r2 = moments[1:] * expected_norm
...     expected_std = np.sqrt(r2 - np.square(expected_mean))
...     np.testing.assert_approx_equal(mean, expected_mean, 5)
...     np.testing.assert_approx_equal(std, expected_std, 5)
...     np.testing.assert_approx_equal(norm, expected_norm, 5)

""")
parameters_to_moments = require_contiguous_aligned(parameters_to_moments)


# Attach a NumPy-style docstring (with doctests) to the `volume_render`
# function imported from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned`.
add_newdoc_ufunc(volume_render, """\
Perform volumetric rendering of a 3D sky map.

Parameters
----------
x : `numpy.ndarray`
    X-coordinate in rendered image
y : `numpy.ndarray`
    Y-coordinate in rendered image
max_distance : float
    Limit of integration from `-max_distance` to `+max_distance`
axis0 : int
    Index of axis to assign to x-coordinate
axis1 : int
    Index of axis to assign to y-coordinate
R : `numpy.ndarray`
    Rotation matrix as provided by `principal_axes`
nest : bool
    HEALPix ordering scheme
prob : `numpy.ndarray`
    Marginal probability (pix^-2)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
image : `numpy.ndarray`
    Rendered image

Examples
--------
Test volume rendering of a normal unit sphere...
First, set up the 3D sky map.

>>> nside = 32
>>> npix = ah.nside_to_npix(nside)
>>> prob = np.ones(npix) / npix
>>> distmu = np.zeros(npix)
>>> distsigma = np.ones(npix)
>>> distnorm = np.ones(npix) * 2.0

The conditional distance distribution should be a chi distribution with
3 degrees of freedom.

>>> from scipy.stats import norm, chi
>>> r = np.linspace(0, 10.0)
>>> actual = conditional_pdf(r, distmu[0], distsigma[0], distnorm[0])
>>> expected = chi(3).pdf(r)
>>> np.testing.assert_almost_equal(actual, expected)

Next, run the volume renderer.

>>> dmax = 4.0
>>> n = 64
>>> s = np.logspace(-dmax, dmax, n)
>>> x, y = np.meshgrid(s, s)
>>> R = np.eye(3)
>>> P = volume_render(x, y, dmax, 0, 1, R, False,
...                   prob, distmu, distsigma, distnorm)

Next, integrate analytically.

>>> P_expected = norm.pdf(x) * norm.pdf(y) * (norm.cdf(dmax) - norm.cdf(-dmax))

Compare the two.

>>> np.testing.assert_almost_equal(P, P_expected, decimal=4)

Check that we get the same answer if the input is in ring ordering.
FIXME: this is a very weak test, because the input sky map is isotropic!

>>> P = volume_render(x, y, dmax, 0, 1, R, True,
...                   prob, distmu, distsigma, distnorm)
>>> np.testing.assert_almost_equal(P, P_expected, decimal=4)

Last, check that we don't have a coordinate singularity at the origin.

>>> x = np.concatenate(([0], np.logspace(1 - n, 0, n) * dmax))
>>> y = 0.0
>>> P = volume_render(x, y, dmax, 0, 1, R, False,
...                   prob, distmu, distsigma, distnorm)
>>> P_expected = norm.pdf(x) * norm.pdf(y) * (norm.cdf(dmax) - norm.cdf(-dmax))
>>> np.testing.assert_allclose(P, P_expected, rtol=1e-4)

""")
volume_render = require_contiguous_aligned(volume_render)


# Attach a NumPy-style docstring (with doctest) to the `marginal_pdf` function
# imported from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned`.
add_newdoc_ufunc(marginal_pdf, """\
Calculate all-sky marginal pdf (ansatz).

Parameters
----------
r : `numpy.ndarray`
    Distance (Mpc)
prob : `numpy.ndarray`
    Marginal probability (pix^-2)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
pdf : `numpy.ndarray`
    Marginal probability density according to ansatz.

Examples
--------

>>> npix = 12
>>> prob, distmu, distsigma, distnorm = np.random.uniform(size=(4, 12))
>>> r = np.linspace(0, 1)
>>> pdf_expected = np.dot(
...     conditional_pdf(r[:, np.newaxis], distmu, distsigma, distnorm), prob)
>>> pdf = marginal_pdf(r, prob, distmu, distsigma, distnorm)
>>> np.testing.assert_allclose(pdf, pdf_expected, rtol=1e-4)

""")
marginal_pdf = require_contiguous_aligned(marginal_pdf)


# Attach a NumPy-style docstring (with doctest) to the `marginal_cdf` function
# imported from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned`.
add_newdoc_ufunc(marginal_cdf, """\
Calculate all-sky marginal cdf (ansatz).

Parameters
----------
r : `numpy.ndarray`
    Distance (Mpc)
prob : `numpy.ndarray`
    Marginal probability (pix^-2)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
cdf : `numpy.ndarray`
    Marginal cumulative probability according to ansatz.

Examples
--------

>>> npix = 12
>>> prob, distmu, distsigma, distnorm = np.random.uniform(size=(4, 12))
>>> r = np.linspace(0, 1)
>>> cdf_expected = np.dot(
...     conditional_cdf(r[:, np.newaxis], distmu, distsigma, distnorm), prob)
>>> cdf = marginal_cdf(r, prob, distmu, distsigma, distnorm)
>>> np.testing.assert_allclose(cdf, cdf_expected, rtol=1e-4)

""")
marginal_cdf = require_contiguous_aligned(marginal_cdf)


# Attach a NumPy-style docstring (with doctest) to the `marginal_ppf` function
# imported from .core, then rebind the name to the version wrapped by
# `require_contiguous_aligned`. "ppf" is the standard "percent point
# function" (inverse cdf).
add_newdoc_ufunc(marginal_ppf, """\
Percent point function (inverse cdf) of marginal distribution of distance
(ansatz).

Parameters
----------
p : `numpy.ndarray`
    The cumulative distribution function
prob : `numpy.ndarray`
    Marginal probability (pix^-2)
distmu : `numpy.ndarray`
    Distance location parameter (Mpc)
distsigma : `numpy.ndarray`
    Distance scale parameter (Mpc)
distnorm : `numpy.ndarray`
    Distance normalization factor (Mpc^-2)

Returns
-------
r : `numpy.ndarray`
    Distance at which the cdf is equal to `p`.

Examples
--------

>>> from astropy.utils.misc import NumpyRNGContext
>>> npix = 12
>>> with NumpyRNGContext(0):
...     prob, distmu, distsigma, distnorm = np.random.uniform(size=(4, 12))
>>> r_expected = np.linspace(0.4, 0.7)
>>> cdf = marginal_cdf(r_expected, prob, distmu, distsigma, distnorm)
>>> r = marginal_ppf(cdf, prob, distmu, distsigma, distnorm)
>>> np.testing.assert_allclose(r, r_expected, rtol=1e-4)

""")
marginal_ppf = require_contiguous_aligned(marginal_ppf)


def ud_grade(prob, distmu, distsigma, *args, **kwargs):
    """
    Upsample or downsample a distance-resolved sky map.

    The ansatz parameters are first converted to conditional moments. The
    probability-weighted moments are resampled together with ``prob``, and
    the result is converted back to ansatz parameters.

    Parameters
    ----------
    prob : `numpy.ndarray`
        Marginal probability (pix^-2)
    distmu : `numpy.ndarray`
        Distance location parameter (Mpc)
    distsigma : `numpy.ndarray`
        Distance scale parameter (Mpc)
    *args, **kwargs :
        Additional arguments to `healpy.ud_grade` (e.g.,
        `nside`, `order_in`, `order_out`).

    Returns
    -------
    prob : `numpy.ndarray`
        Resampled marginal probability (pix^-2)
    distmu : `numpy.ndarray`
        Resampled distance location parameter (Mpc)
    distsigma : `numpy.ndarray`
        Resampled distance scale parameter (Mpc)
    distnorm : `numpy.ndarray`
        Resampled distance normalization factor (Mpc^-2)

    """
    # Pixels with non-finite ansatz parameters carry no usable distance
    # information; zero their moments so they do not poison the sums below.
    bad = ~(np.isfinite(distmu) & np.isfinite(distsigma))
    distmean, diststd, _ = parameters_to_moments(distmu, distsigma)
    distmean[bad] = 0
    diststd[bad] = 0

    # Resample the probability-weighted moments. The first moment must use
    # the conditional mean computed above (distmean), not the location
    # parameter distmu: the two differ in general, and using distmu would
    # also bypass the zeroing of bad pixels.
    # NOTE(review): the variance is combined as a probability-weighted mean
    # of per-pixel variances (spread between pixel means is not added);
    # confirm this approximation is intended.
    distmean = hp.ud_grade(prob * distmean, *args, power=-2, **kwargs)
    diststd = hp.ud_grade(prob * np.square(diststd), *args, power=-2, **kwargs)
    prob = hp.ud_grade(prob, *args, power=-2, **kwargs)
    distmean /= prob
    diststd = np.sqrt(diststd / prob)

    # Flag output pixels derived from bad input pixels.
    # NOTE(review): this relies on hp.ud_grade's handling of a boolean map;
    # confirm which output pixels end up flagged.
    bad = ~hp.ud_grade(~bad, *args, power=-2, **kwargs)
    distmean[bad] = np.inf
    diststd[bad] = 1
    distmu, distsigma, distnorm = moments_to_parameters(distmean, diststd)
    return prob, distmu, distsigma, distnorm


def _conditional_kde(n, X, Cinv, W):
    Cinv_n = np.dot(Cinv, n)
    cinv = np.dot(n, Cinv_n)
    x = np.dot(Cinv_n, X) / cinv
    w = W * (0.5 / np.pi) * np.sqrt(np.linalg.det(Cinv) / cinv) * np.exp(
        0.5 * (np.square(x) * cinv - (np.dot(Cinv, X) * X).sum(0)))
    return x, cinv, w


def conditional_kde(n, datasets, inverse_covariances, weights):
    """
    Restrict every component of a KDE mixture to the ray along ``n``.

    Parameters
    ----------
    n : `numpy.ndarray`
        Direction vector (length 3).
    datasets : list of `numpy.ndarray`
        Sample points of each KDE component.
    inverse_covariances : list of `numpy.ndarray`
        Inverse covariance matrix of each KDE component.
    weights : list
        Weight of each KDE component.

    Returns
    -------
    list of tuple
        One ``(x, cinv, w)`` tuple per component, as produced by
        `_conditional_kde`. Iteration stops at the shortest of the three
        input sequences.
    """
    components = zip(datasets, inverse_covariances, weights)
    result = []
    for dataset, inv_cov, weight in components:
        result.append(_conditional_kde(n, dataset, inv_cov, weight))
    return result


def cartesian_kde_to_moments(n, datasets, inverse_covariances, weights):
    """
    Calculate the marginal probability, conditional mean, and conditional
    standard deviation of a mixture of three-dimensional kernel density
    estimators (KDEs), in a given direction specified by a unit vector.

    Parameters
    ----------
    n : `numpy.ndarray`
        A unit vector; an array of length 3.
    datasets : list of `numpy.ndarray`
        A list of 2D Numpy arrays specifying the sample points of the KDEs.
        The first dimension of each array is 3.
    inverse_covariances : list of `numpy.ndarray`
        A list of 3x3 matrices specifying the inverses of the covariance
        matrices of the KDEs. The list has the same length as the datasets
        parameter.
    weights : list
        A list of floating-point weights.

    Returns
    -------
    prob : float
        The marginal probability in direction n, integrated over all distances.
    mean : float
        The conditional mean in direction n, or ``inf`` if the normalized
        variance is negative or undefined (e.g. when ``prob`` is zero or
        ``datasets`` is empty).
    std : float
        The conditional standard deviation in direction n, or 1.0 in the
        same invalid cases.

    Examples
    --------
    >>> # Some imports
    >>> import scipy.stats
    >>> import scipy.integrate
    >>> # Construct random dataset for KDE
    >>> np.random.seed(0)
    >>> nclusters = 5
    >>> ndata = np.random.randint(0, 1000, nclusters)
    >>> covs = [np.random.uniform(0, 1, size=(3, 3)) for _ in range(nclusters)]
    >>> covs = [_ + _.T + 3 * np.eye(3) for _ in covs]
    >>> means = np.random.uniform(-1, 1, size=(nclusters, 3))
    >>> datasets = [np.random.multivariate_normal(m, c, n).T
    ...     for m, c, n in zip(means, covs, ndata)]
    >>> weights = ndata / float(np.sum(ndata))
    >>>
    >>> # Construct set of KDEs
    >>> kdes = [scipy.stats.gaussian_kde(_) for _ in datasets]
    >>>
    >>> # Random unit vector n
    >>> n = np.random.normal(size=3)
    >>> n /= np.sqrt(np.sum(np.square(n)))
    >>>
    >>> # Analytically evaluate conditional mean and std. dev. in direction n
    >>> datasets = [_.dataset for _ in kdes]
    >>> inverse_covariances = [_.inv_cov for _ in kdes]
    >>> result_prob, result_mean, result_std = cartesian_kde_to_moments(
    ...     n, datasets, inverse_covariances, weights)
    >>>
    >>> # Numerically integrate conditional distance moments
    >>> def rkbar(k):
    ...     def integrand(r):
    ...         return r ** k * np.sum([kde(r * n) * weight
    ...             for kde, weight in zip(kdes, weights)])
    ...     integral, err = scipy.integrate.quad(integrand, 0, np.inf)
    ...     return integral
    ...
    >>> r0bar = rkbar(2)
    >>> r1bar = rkbar(3)
    >>> r2bar = rkbar(4)
    >>>
    >>> # Extract conditional mean and std. dev.
    >>> r1bar /= r0bar
    >>> r2bar /= r0bar
    >>> expected_prob = r0bar
    >>> expected_mean = r1bar
    >>> expected_std = np.sqrt(r2bar - np.square(r1bar))
    >>>
    >>> # Check that the two methods give almost the same result
    >>> np.testing.assert_almost_equal(result_prob, expected_prob)
    >>> np.testing.assert_almost_equal(result_mean, expected_mean)
    >>> np.testing.assert_almost_equal(result_std, expected_std)
    >>>
    >>> # Check that KDE is normalized over unit sphere.
    >>> nside = 32
    >>> npix = ah.nside_to_npix(nside)
    >>> prob, _, _ = np.transpose([cartesian_kde_to_moments(
    ...     np.asarray(hp.pix2vec(nside, ipix)),
    ...     datasets, inverse_covariances, weights)
    ...     for ipix in range(npix)])
    >>> result_integral = prob.sum() * hp.nside2pixarea(nside)
    >>> np.testing.assert_almost_equal(result_integral, 1.0, decimal=4)

    """
    # Initialize moments of conditional KDE. NumPy scalars (rather than
    # Python ints) make an empty mixture produce 0/0 -> nan, which the
    # invalid-value handling below turns into (0, inf, 1.0), instead of
    # raising ZeroDivisionError.
    r0bar = np.float64(0)
    r1bar = np.float64(0)
    r2bar = np.float64(0)

    # Loop over KDEs.
    for X, Cinv, W in zip(datasets, inverse_covariances, weights):
        x, cinv, w = _conditional_kde(n, X, Cinv, W)

        # Accumulate moments of conditional KDE: closed-form integrals over
        # r >= 0 of r**(2+k) times the 1D Gaussian with mean x, variance c.
        c = 1 / cinv
        x2 = np.square(x)
        a = scipy.special.ndtr(x * np.sqrt(cinv))
        b = np.sqrt(0.5 / np.pi * c) * np.exp(-0.5 * cinv * x2)
        r0bar_ = (x2 + c) * a + x * b
        r1bar_ = x * (x2 + 3 * c) * a + (x2 + 2 * c) * b
        r2bar_ = (x2 * x2 + 6 * x2 * c + 3 * c * c) * a + x * (x2 + 5 * c) * b
        r0bar += np.mean(w * r0bar_)
        r1bar += np.mean(w * r1bar_)
        r2bar += np.mean(w * r2bar_)

    # Normalize moments.
    with np.errstate(invalid='ignore'):
        r1bar /= r0bar
        r2bar /= r0bar
    var = r2bar - np.square(r1bar)

    # Handle invalid values (negative or nan variance).
    if var >= 0:
        mean = r1bar
        std = np.sqrt(var)
    else:
        mean = np.inf
        std = 1.0
    prob = r0bar

    # Done!
    return prob, mean, std


def principal_axes(prob, distmu, distsigma, nest=False):
    """
    Calculate the principal axes of a 3D sky map.

    The eigenvectors are those of the matrix ``sum(w * outer(v, v))`` over
    pixels, where ``v`` is the unit vector of each pixel and the weight
    ``w = prob * (diststd**2 + distmean**2)`` is the pixel probability times
    the conditional second moment of distance. Pixels with non-finite
    ``prob``, ``distmu``, or ``distsigma`` are excluded.

    Parameters
    ----------
    prob : `numpy.ndarray`
        Marginal probability (pix^-2)
    distmu : `numpy.ndarray`
        Distance location parameter (Mpc)
    distsigma : `numpy.ndarray`
        Distance scale parameter (Mpc)
    nest : bool, optional
        HEALPix ordering scheme: True for NESTED, False (default) for RING.

    Returns
    -------
    R : `numpy.ndarray`
        3x3 matrix whose columns are the eigenvectors, in order of increasing
        eigenvalue (as returned by `numpy.linalg.eigh`). The overall sign is
        chosen so that the determinant is positive, giving a proper rotation
        matrix suitable as the ``R`` argument of `volume_render`.

    """
    npix = len(prob)
    nside = ah.npix_to_nside(npix)
    good = np.isfinite(prob) & np.isfinite(distmu) & np.isfinite(distsigma)
    ipix = np.flatnonzero(good)
    distmean, diststd, _ = parameters_to_moments(distmu[good], distsigma[good])
    # Weight each direction by its probability times E[r^2] in that pixel.
    mass = prob[good] * (np.square(diststd) + np.square(distmean))
    xyz = np.asarray(hp.pix2vec(nside, ipix, nest=nest))
    cov = np.dot(xyz * mass, xyz.T)
    _, V = np.linalg.eigh(cov)
    # Flip the sign of all axes if needed to make V right-handed (det = +1).
    if np.linalg.det(V) < 0:
        V = -V
    return V


def parameters_to_marginal_moments(prob, distmu, distsigma):
    """Calculate the marginal (integrated all-sky) mean and standard deviation
    of distance from the ansatz parameters.

    Pixels where any of ``prob``, ``distmu``, or ``distsigma`` is not finite
    are dropped before the moments are accumulated.

    Parameters
    ----------
    prob : `numpy.ndarray`
        Marginal probability per pixel. The result is a properly normalized
        mean and standard deviation only if ``prob`` sums to one over the
        pixels that are kept; it is not renormalized here.
    distmu : `numpy.ndarray`
        Distance location parameter (Mpc)
    distsigma : `numpy.ndarray`
        Distance scale parameter (Mpc)

    Returns
    -------
    distmean : float
        Mean distance (Mpc)
    diststd : float
        Std. deviation of distance (Mpc)

    """
    keep = np.isfinite(prob) & np.isfinite(distmu) & np.isfinite(distsigma)
    weight = prob[keep]
    distmean, diststd, _ = parameters_to_moments(distmu[keep], distsigma[keep])
    # First and second moments of distance, marginalized over the sky.
    first = np.sum(weight * distmean)
    second = np.sum(weight * (np.square(diststd) + np.square(distmean)))
    return first, np.sqrt(second - np.square(first))


# Remove these helper names from the module namespace once they are no longer
# needed. NOTE(review): they are presumably imported earlier in this module
# (outside the visible chunk) only for setup — confirm.
del add_newdoc_ufunc, require_contiguous_aligned

ligo/skymap/healpix_tree.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Multiresolution HEALPix trees
"""
import astropy_healpix as ah
from astropy import units as u
import numpy as np
import healpy as hp
import collections
import itertools

# Names exported by ``from ... import *``.
__all__ = ('HEALPIX_MACHINE_ORDER', 'HEALPIX_MACHINE_NSIDE', 'HEALPixTree',
           'adaptive_healpix_histogram', 'interpolate_nested',
           'reconstruct_nested')


# Maximum 64-bit HEALPix resolution.
# At order 29 there are 12 * 4**29 (about 3.5e18) pixels, which fits in a
# signed 64-bit integer; order 30 (about 1.4e19 pixels) would not.
HEALPIX_MACHINE_ORDER = 29
HEALPIX_MACHINE_NSIDE = ah.level_to_nside(HEALPIX_MACHINE_ORDER)


# Leaf record yielded by tree traversals when extra=True. The typename matches
# the module-level variable name so that instances can be pickled (pickle
# looks the class up by name in its module) and so that its repr is distinct
# from _HEALPixTreeVisit.
_HEALPixTreeVisitExtra = collections.namedtuple(
    '_HEALPixTreeVisitExtra', 'nside full_nside ipix ipix0 ipix1 value')


# Leaf record yielded by tree traversals when extra=False.
_HEALPixTreeVisit = collections.namedtuple(
    '_HEALPixTreeVisit', 'nside ipix')


class HEALPixTree:
    """Data structure used internally by the function
    adaptive_healpix_histogram().

    The tree is built from samples given as NESTED pixel indices at order
    ``HEALPIX_MACHINE_ORDER``. The root node stands for the whole sky and is
    always split into the 12 base pixels; every other internal node has 4
    children. A non-root node is split further while it holds at least
    ``max_samples_per_pixel`` samples and its ``order`` is below
    ``max_order``; otherwise it becomes a leaf that stores its samples.

    Parameters
    ----------
    samples : array_like
        NESTED pixel indices at order ``HEALPIX_MACHINE_ORDER``.
    max_samples_per_pixel : int
        A node holding at least this many samples is split (subject to
        ``max_order``).
    max_order : int
        Non-root nodes whose ``order`` is ``max_order`` or more are not split.
    order : int, optional
        Depth of this node (0 for the root). Used internally.
    needs_sort : bool, optional
        Whether ``samples`` must be sorted first. Used internally: children
        receive runs that are already sorted.
    """

    def __init__(
            self, samples, max_samples_per_pixel, max_order,
            order=0, needs_sort=True):
        if needs_sort:
            samples = np.sort(samples)
        # The root (order == 0) represents the whole sky rather than a single
        # HEALPix pixel, so it is always split into the 12 base pixels. If it
        # were left as a leaf (too few samples, or max_order == 0), the
        # traversal methods, which iterate over the root's children, would
        # fail.
        split = order == 0 or (
            len(samples) >= max_samples_per_pixel and order < max_order)
        if split:
            # All nodes have 4 children, except for the root node,
            # which has 12.
            nchildren = 12 if order == 0 else 4
            self.samples = None
            # Start with empty leaves so that children without samples exist.
            self.children = [
                HEALPixTree(
                    [], max_samples_per_pixel, max_order, order=order + 1)
                for _ in range(nchildren)]
            # ``samples`` is sorted, so each run of equal keys is exactly the
            # set of samples that fall inside one child pixel.
            for ipix, samples in itertools.groupby(
                    samples, self.key_for_order(order)):
                self.children[ipix % nchildren] = HEALPixTree(
                    list(samples), max_samples_per_pixel, max_order,
                    order=order + 1, needs_sort=False)
        else:
            # There are few enough samples that we can make this cell a leaf.
            self.samples = list(samples)
            self.children = None

    @staticmethod
    def key_for_order(order):
        """Create a function that downsamples full-resolution pixel indices.

        The returned function maps a NESTED index at order
        ``HEALPIX_MACHINE_ORDER`` to the NESTED index of its ancestor at
        order ``order``.
        """
        return lambda ipix: ipix >> np.int64(
            2 * (HEALPIX_MACHINE_ORDER - order))

    @property
    def order(self):
        """Return the maximum HEALPix order required to represent this tree,
        which is the same as the tree depth."""
        if self.children is None:
            return 0
        else:
            return 1 + max(child.order for child in self.children)

    def _visit(self, order, full_order, ipix, extra):
        """Recursively yield the leaves below this node.

        ``order`` is the HEALPix order of this node's NESTED pixel ``ipix``;
        ``full_order`` is the order at which the ``ipix0``/``ipix1`` ranges
        (and ``full_nside``) are expressed.
        """
        if self.children is None:
            nside = 1 << order
            full_nside = 1 << full_order
            ipix0 = ipix << 2 * (full_order - order)
            ipix1 = (ipix + 1) << 2 * (full_order - order)
            if extra:
                yield _HEALPixTreeVisitExtra(
                    nside, full_nside, ipix, ipix0, ipix1, self.samples)
            else:
                yield _HEALPixTreeVisit(nside, ipix)
        else:
            for i, child in enumerate(self.children):
                yield from child._visit(
                    order + 1, full_order, (ipix << 2) + i, extra)

    def _visit_depthfirst(self, extra):
        """Yield leaves in NESTED index order, starting from the 12 base
        pixels (the root's children)."""
        order = self.order
        for ipix, child in enumerate(self.children):
            yield from child._visit(0, order, ipix, extra)

    def _visit_breadthfirst(self, extra):
        """Return leaves sorted by resolution (coarsest first), then index."""
        # ``key`` is keyword-only in Python 3's sorted().
        return sorted(
            self._visit_depthfirst(extra),
            key=lambda node: (node.nside, node.ipix))

    def visit(self, order='depthfirst', extra=True):
        """Traverse the leaves of the HEALPix tree.

        Parameters
        ----------
        order : string, optional
            Traversal order: 'depthfirst' (the default) or 'breadthfirst'.

        extra : bool
            Whether to output extra information about the pixel
            (default is True).

        Yields
        ------
        nside : int
            The HEALPix resolution of the node.

        full_nside : int, present if extra=True
            The HEALPix resolution of the deepest node in the tree.

        ipix : int
            The nested HEALPix index of the node.

        ipix0 : int, present if extra=True
            The start index of the range of pixels spanned by the node at the
            resolution `full_nside`.

        ipix1 : int, present if extra=True
            The end index (exclusive) of the range of pixels spanned by the
            node at the resolution `full_nside`.

        value : list, present if extra=True
            The list of samples contained in the node.

        Raises
        ------
        KeyError
            If ``order`` is not one of the supported traversal orders.

        Examples
        --------

        >>> ipix = np.arange(12) * HEALPIX_MACHINE_NSIDE**2
        >>> tree = HEALPixTree(ipix, max_samples_per_pixel=1, max_order=1)
        >>> [tuple(_) for _ in tree.visit(extra=False)]
        [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]
        """
        funcs = {'depthfirst': self._visit_depthfirst,
                 'breadthfirst': self._visit_breadthfirst}
        func = funcs[order]
        yield from func(extra)

    @property
    def flat_bitmap(self):
        """Return flattened HEALPix representation.

        The result is a NESTED map at order ``self.order``. Each pixel holds
        the number of samples in its enclosing leaf divided by that leaf's
        pixel area in steradians.
        """
        m = np.empty(ah.nside_to_npix(ah.level_to_nside(self.order)))
        for nside, full_nside, ipix, ipix0, ipix1, samples in self.visit():
            pixarea = ah.nside_to_pixel_area(nside).to_value(u.sr)
            m[ipix0:ipix1] = len(samples) / pixarea
        return m


def adaptive_healpix_histogram(
        theta, phi, max_samples_per_pixel, nside=-1, max_nside=-1, nest=False):
    """Adaptively histogram the posterior samples represented by the
    (theta, phi) points using a recursively subdivided HEALPix tree.

    Nodes are subdivided while they contain at least
    ``max_samples_per_pixel`` samples and the depth limit has not been
    reached, so leaves normally hold fewer than ``max_samples_per_pixel``
    samples. The tree is then flattened to a fixed-resolution HEALPix image
    with a resolution appropriate for the depth of the tree. If ``nside`` is
    specified, the result is resampled to that resolution.

    Parameters
    ----------
    theta, phi : array_like
        Sample positions in radians, as accepted by :func:`healpy.ang2pix`.
    max_samples_per_pixel : int
        Splitting threshold for tree nodes.
    nside : int, optional
        Output resolution. If not -1, it also limits the depth of the tree.
    max_nside : int, optional
        If not -1, limits the depth of the tree.
    nest : bool, optional
        Return NESTED ordering if True, RING ordering otherwise.

    Returns
    -------
    p : `numpy.ndarray`
        HEALPix map normalized so that it sums to one.
    """
    # Calculate pixel index of every sample, at the maximum 64-bit resolution.
    # At this resolution each pixel is only 0.2 mas across, so the 64-bit
    # pixel indices serve as a proxy for the true sample coordinates and no
    # further trigonometry is needed.
    ipix = hp.ang2pix(HEALPIX_MACHINE_NSIDE, theta, phi, nest=True)

    # The tree depth is limited by the smaller of whichever resolutions were
    # supplied (-1 means "not given").
    limits = [value for value in (nside, max_nside) if value != -1]
    if limits:
        max_order = ah.nside_to_level(min(limits))
    else:
        max_order = HEALPIX_MACHINE_ORDER

    # Build the tree and flatten it to a fixed-resolution bitmap.
    p = HEALPixTree(ipix, max_samples_per_pixel, max_order).flat_bitmap

    # Resample to the requested output resolution, if any.
    if nside != -1:
        p = hp.ud_grade(p, nside, order_in='NESTED', order_out='NESTED')

    p /= np.sum(p)

    return p if nest else hp.reorder(p, n2r=True)


def _interpolate_level(m):
    """Interpolate over coarse tiles of a NESTED HEALPix map, in place.

    A group of four consecutive NESTED pixels that all hold the same value is
    treated as one coarse tile. The map is downsampled by one level and
    processed recursively. Each coarse tile is then overwritten with values
    interpolated (via ``healpy.get_interp_val``) from the downsampled map at
    the centers of its four pixels. Maps with 12 or fewer pixels are left
    unchanged.
    """
    npix = len(m)
    if npix <= 12:
        return

    first = m[0::4]
    uniform = (first == m[1::4]) & (first == m[2::4]) & (first == m[3::4])
    tiles = np.flatnonzero(uniform)
    if tiles.size == 0:
        return

    # Expand each coarse tile index into its four fine NESTED pixel indices,
    # tile by tile in ascending order.
    ipix = (4 * tiles[:, np.newaxis] + np.arange(4, dtype=np.intp)).ravel()

    nside = ah.npix_to_nside(npix)
    coarse = hp.ud_grade(
        m, nside // 2, order_in='NESTED', order_out='NESTED')
    _interpolate_level(coarse)

    theta, phi = hp.pix2ang(nside, ipix, nest=True)
    m[ipix] = hp.get_interp_val(coarse, theta, phi, nest=True)


def interpolate_nested(m, nest=False):
    """
    Apply bilinear interpolation to a multiresolution HEALPix map, assuming
    that runs of pixels containing identical values are nodes of the tree. This
    smooths out the stair-step effect that may be noticeable in contour plots.

    Here is how it works. Consider a coarse tile surrounded by base tiles, like
    this::

                +---+---+
                |   |   |
                +-------+
                |   |   |
        +---+---+---+---+---+---+
        |   |   |       |   |   |
        +-------+       +-------+
        |   |   |       |   |   |
        +---+---+---+---+---+---+
                |   |   |
                +-------+
                |   |   |
                +---+---+

    The value within the central coarse tile is computed by downsampling the
    sky map (averaging the fine tiles), upsampling again (with bilinear
    interpolation), and then finally copying the interpolated values within the
    coarse tile back to the full-resolution sky map. This process is applied
    recursively at all successive HEALPix resolutions.

    Note that this method suffers from a minor discontinuity artifact at the
    edges of regions of coarse tiles, because it temporarily treats the
    bordering fine tiles as constant. However, this artifact seems to have only
    a minor effect on generating contour plots.

    Parameters
    ----------

    m : `~numpy.ndarray`
        a HEALPix array

    nest : bool, default: False
        Whether the input array is stored in the `NESTED` indexing scheme
        (True) or the `RING` indexing scheme (False).

    Returns
    -------

    `~numpy.ndarray`
        The interpolated map, in the same indexing scheme as the input.
        The input array itself is not modified.

    """
    # The interpolation works in place on a NESTED array, so give it either an
    # explicit copy or the new array produced by reordering.
    work = m.copy() if nest else hp.reorder(m, r2n=True)

    _interpolate_level(work)

    # Convert back to RING indexing if necessary.
    return work if nest else hp.reorder(work, n2r=True)


def _reconstruct_nested_breadthfirst(m, extra):
    """Yield the leaves of the multiresolution tree encoded in NESTED map
    ``m``, coarsest resolution first.

    At each order, a block of pixels becomes a leaf if no ancestor has
    already claimed it and all of its values are equal. Two NaNs count as
    equal. At the finest order, every pixel not yet claimed is a leaf.
    """
    m = np.asarray(m)
    max_npix = len(m)
    max_nside = ah.npix_to_nside(max_npix)
    max_order = ah.nside_to_level(max_nside)
    covered = np.zeros(max_npix, dtype=bool)

    for level in range(max_order + 1):
        nside = ah.level_to_nside(level)
        span = max_npix // ah.nside_to_npix(nside)

        if span == 1:
            leaves = ~covered
        else:
            tiles = m.reshape(-1, span)
            head = tiles[:, :1]
            rest = tiles[:, 1:]
            # ``x != x`` is True only for NaN, so NaN matches NaN here.
            same = (head == rest) | ((head != head) & (rest != rest))
            leaves = (same.all(axis=1)
                      & ~covered.reshape(-1, span).any(axis=1))

        for ipix in np.flatnonzero(leaves):
            ipix0 = ipix * span
            ipix1 = ipix0 + span
            covered[ipix0:ipix1] = True
            if not extra:
                yield _HEALPixTreeVisit(nside, ipix)
            else:
                yield _HEALPixTreeVisitExtra(
                    nside, max_nside, ipix, ipix0, ipix1, m[ipix0])


def _reconstruct_nested_depthfirst(m, extra):
    """Return the tree leaves of ``m`` in depth-first order.

    Leaves partition the map, so sorting them by the first full-resolution
    pixel each one spans gives depth-first (NESTED) order.
    """
    ordered = sorted(
        _reconstruct_nested_breadthfirst(m, True),
        key=lambda node: node.ipix0)
    if extra:
        return ordered
    return (_HEALPixTreeVisit(node.nside, node.ipix) for node in ordered)


def reconstruct_nested(m, order='depthfirst', extra=True):
    """Reconstruct the leaves of a multiresolution tree.

    Parameters
    ----------
    m : `~numpy.ndarray`
        A HEALPix array in the NESTED ordering scheme.

    order : {'depthfirst', 'breadthfirst'}, optional
        Traversal order: 'depthfirst' (the default) or 'breadthfirst'.

    extra : bool
        Whether to output extra information about the pixel (default is True).

    Yields
    ------
    nside : int
        The HEALPix resolution of the node.

    full_nside : int, present if extra=True
        The HEALPix resolution of the deepest node in the tree.

    ipix : int
        The nested HEALPix index of the node.

    ipix0 : int, present if extra=True
        The start index of the range of pixels spanned by the node at the
        resolution `full_nside`.

    ipix1 : int, present if extra=True
        The end index (exclusive) of the range of pixels spanned by the node
        at the resolution `full_nside`.

    value : scalar, present if extra=True
        The value of the map at the node (the element ``m[ipix0]``).

    Raises
    ------
    KeyError
        If ``order`` is not a supported traversal order. The error is raised
        when iteration starts, because this function is a generator.

    Examples
    --------

    An nside=1 array of all zeros:

    >>> m = np.zeros(12)
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]

    An nside=1 array of distinct values:

    >>> m = range(12)
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]

    An nside=8 array of zeros:

    >>> m = np.zeros(768)
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]

    An nside=2 array, all zeros except for four consecutive distinct elements:

    >>> m = np.zeros(48); m[:4] = range(4)
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (2, 0), (2, 1), (2, 2), (2, 3)]

    Same, but in depthfirst order:

    >>> result = reconstruct_nested(m, order='depthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(2, 0), (2, 1), (2, 2), (2, 3), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11)]

    An nside=2 array, all elements distinct except for four consecutive zeros:

    >>> m = np.arange(48); m[:4] = 0
    >>> result = reconstruct_nested(m, order='breadthfirst', extra=False)
    >>> [tuple(_) for _ in result]
    [(1, 0), (2, 4), (2, 5), (2, 6), (2, 7), (2, 8), (2, 9), (2, 10), (2, 11), (2, 12), (2, 13), (2, 14), (2, 15), (2, 16), (2, 17), (2, 18), (2, 19), (2, 20), (2, 21), (2, 22), (2, 23), (2, 24), (2, 25), (2, 26), (2, 27), (2, 28), (2, 29), (2, 30), (2, 31), (2, 32), (2, 33), (2, 34), (2, 35), (2, 36), (2, 37), (2, 38), (2, 39), (2, 40), (2, 41), (2, 42), (2, 43), (2, 44), (2, 45), (2, 46), (2, 47)]
    """
    traversal = {'breadthfirst': _reconstruct_nested_breadthfirst,
                 'depthfirst': _reconstruct_nested_depthfirst}[order]
    yield from traversal(m, extra)

ligo/skymap/kde.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
#
# Copyright (C) 2012-2020  Will M. Farr <will.farr@ligo.org>
#                          Leo P. Singer <leo.singer@ligo.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import copyreg
from functools import partial

import astropy_healpix as ah
from astropy.coordinates import SkyCoord
from astropy.table import Table
from astropy import units as u
from astropy.utils.misc import NumpyRNGContext
import healpy as hp
import logging
import numpy as np
from scipy.stats import gaussian_kde

from . import distance
from . import moc
from . import omp
from .coordinates import EigenFrame
from .util import progress_map

# NOTE(review): logging.getLogger() with no name returns the root logger, so
# messages from this module are not attributed to a module-specific logger.
log = logging.getLogger()

# Names exported by ``from ... import *``.
__all__ = ('BoundedKDE', 'Clustered2DSkyKDE', 'Clustered3DSkyKDE',
           'Clustered2Plus1DSkyKDE')


class BoundedKDE(gaussian_kde):
    """Density estimation using a KDE on bounded domains.

    Bounds can be any combination of low or high (if no bound, set to
    ``float('inf')`` or ``float('-inf')``), and can be periodic or
    non-periodic.  Cannot handle topologies that have
    multi-dimensional periodicities; will only handle topologies that
    are direct products of (arbitrary numbers of) R, [0,1], and S1.

    Non-periodic bounds are handled by adding the density of points mirrored
    about each finite bound. Periodic dimensions are handled by adding the
    density of points shifted by one period in each direction.

    Parameters
    ----------
    pts : :class:`numpy.ndarray`
        ``(Ndim, Npts)`` shaped array of points (as in :class:`gaussian_kde`).
    low
        Lower bounds; if ``None``, assume no lower bounds.
    high
        Upper bounds; if ``None``, assume no upper bounds.
    periodic
        Boolean array giving periodicity in each dimension; if
        ``None`` assume no dimension is periodic.
    bw_method : optional
        Bandwidth estimation method (see :class:`gaussian_kde`).

    """

    def __init__(self, pts, low=-np.inf, high=np.inf, periodic=False,
                 bw_method=None):

        super().__init__(pts, bw_method=bw_method)
        # Honor the documented meaning of None. Without this, None would be
        # cast to NaN, and every reflected image (and so every density)
        # would become NaN.
        if low is None:
            low = -np.inf
        if high is None:
            high = np.inf
        if periodic is None:
            periodic = False
        self._low = np.broadcast_to(
            low, self.d).astype(self.dataset.dtype)
        self._high = np.broadcast_to(
            high, self.d).astype(self.dataset.dtype)
        self._periodic = np.broadcast_to(
            periodic, self.d).astype(bool)

    def evaluate(self, pts):
        """Evaluate the KDE at the given points."""
        pts = np.atleast_2d(pts)
        d, m = pts.shape
        # Accept a single point given as a 1 x Ndim row.
        if d != self.d and d == 1 and m == self.d:
            pts = pts.T

        pts_orig = pts
        # Work on a floating-point copy. Reflected and shifted images are
        # generally non-integer, so writing them into an integer array would
        # silently truncate them (or raise, for the in-place shifts).
        pts = np.array(
            pts_orig, dtype=np.result_type(pts_orig.dtype, np.float64))

        den = super().evaluate(pts)

        for i, (low, high, period) in enumerate(zip(self._low, self._high,
                                                    self._periodic)):
            if period:
                p = high - low

                pts[i, :] += p
                den += super().evaluate(pts)

                pts[i, :] -= 2.0 * p
                den += super().evaluate(pts)

                pts[i, :] = pts_orig[i, :]

            else:
                if not np.isneginf(low):
                    pts[i, :] = 2.0 * low - pts[i, :]
                    den += super().evaluate(pts)
                    pts[i, :] = pts_orig[i, :]

                if not np.isposinf(high):
                    pts[i, :] = 2.0 * high - pts[i, :]
                    den += super().evaluate(pts)
                    pts[i, :] = pts_orig[i, :]

        return den

    __call__ = evaluate

    def quantile(self, pt):
        """Quantile of ``pt``, evaluated by a greedy algorithm.

        Parameters
        ----------
        pt
            The point at which the quantile value is to be computed.

        Notes
        -----
        The quantile of ``pt`` is the fraction of points used to construct the
        KDE that have a lower KDE density than ``pt``.

        """
        return np.count_nonzero(self(self.dataset) < self(pt)) / self.n


def km_assign(mus, cov, pts):
    """Implement the assignment step in the k-means algorithm.

    Given a set of centers, ``mus`` (shape ``(k, ndim)``), a covariance
    matrix used to produce a metric on the space, ``cov``, and a set of
    points, ``pts`` (shape ``(npts, ndim)``), assigns each point to its
    nearest center under the squared Mahalanobis distance
    ``dx @ inv(cov) @ dx``. Returns an array of indices of shape ``(npts,)``
    giving the assignments.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``cov`` is singular. The same matrix is used for every center, so
        a singular ``cov`` makes every distance undefined. The error is
        therefore propagated rather than turned into all-NaN distances,
        which would make ``np.nanargmin`` fail with an unrelated
        "All-NaN slice" ValueError.
    """
    k = mus.shape[0]
    n = pts.shape[0]

    dists = np.zeros((k, n))

    for i, mu in enumerate(mus):
        dx = pts - mu
        dists[i, :] = np.sum(dx * np.linalg.solve(cov, dx.T).T, axis=1)

    # nanargmin skips centers whose distance is NaN (e.g. a centroid that
    # itself contains NaN).
    return np.nanargmin(dists, axis=0)


def km_centroids(pts, assign, k):
    """Implement the centroid-update step of the k-means algorithm.

    Parameters
    ----------
    pts
        Points, shape ``(npts, ndim)``.
    assign
        Region index of each point, shape ``(npts,)``.
    k
        Number of regions (means).

    Returns
    -------
    centroids
        Array of shape ``(k, ndim)`` holding the mean of the points in each
        region. A region with no points is re-seeded with one point of
        ``pts`` chosen by ``np.random.randint``, so the result depends on
        NumPy's global random state.
    """
    centroids = np.zeros((k, pts.shape[1]))
    for region in range(k):
        members = pts[assign == region, :]
        if members.shape[0] == 0:
            centroids[region, :] = pts[np.random.randint(pts.shape[0]), :]
        else:
            centroids[region, :] = members.mean(axis=0)
    return centroids


def k_means(pts, k):
    """Perform k-means clustering on the set of points.

    Distances use the metric defined by the sample covariance of ``pts``.
    The initial centers are the first ``k`` rows of a random permutation of
    ``pts``, so the result depends on NumPy's global random state. Centroid
    updates and reassignments alternate until the assignment stops changing.

    Parameters
    ----------
    pts
        Array of shape ``(npts, ndim)`` giving the points on which k-means is
        to operate.
    k
        Positive integer giving the number of regions.

    Returns
    -------
    centroids
        An ``(k, ndim)`` array giving the centroid of each region.
    assign
        An ``(npts,)`` array of integers between 0 (inclusive) and k
        (exclusive) indicating the assignment of each point to a region.

    """
    # NOTE: an assert is skipped when Python runs with -O.
    assert pts.shape[0] > k, 'must have more points than means'

    metric = np.cov(pts, rowvar=False)

    centroids = np.random.permutation(pts)[:k, :]
    labels = km_assign(centroids, metric, pts)
    while True:
        centroids = km_centroids(pts, labels, k)
        new_labels = km_assign(centroids, metric, pts)
        if np.array_equal(new_labels, labels):
            return centroids, labels
        labels = new_labels


def _cluster(cls, pts, trials, i, seed, jobs):
    """Run one clustering trial and score it.

    The trial index ``i`` encodes the number of clusters as ``i // trials``.
    For more than one cluster, k-means runs with NumPy's global RNG seeded
    with ``i + seed``.

    Returns
    -------
    tuple
        ``(bic, k, kdes)`` from ``cls(pts, assign=...)`` on success, or the
        1-tuple ``(-inf,)`` if a LinAlgError was raised.
    """
    # FIXME: This is an inelegant hack to disable OpenMP inside subprocesses.
    # See https://git.ligo.org/lscsoft/ligo.skymap/issues/7.
    if jobs != 1:
        omp.num_threads = 1

    nclusters = i // trials
    if nclusters == 0:
        raise ValueError('Expected at least one cluster')

    try:
        if nclusters == 1:
            # A single cluster needs no k-means: every point has label 0.
            labels = np.zeros(len(pts), dtype=np.intp)
        else:
            with NumpyRNGContext(i + seed):
                _, labels = k_means(pts, nclusters)
        fitted = cls(pts, assign=labels)
    except np.linalg.LinAlgError:
        return (-np.inf,)
    return fitted.bic, nclusters, fitted.kdes


class ClusteredKDE:
    """Mixture of Gaussian KDEs, one per cluster of points.

    If ``assign`` is None, the number of clusters is chosen automatically.
    For every candidate ``k`` from 1 to ``max_k``, ``trials`` clusterings are
    attempted. The attempt with the largest ``bic`` is kept, with ties broken
    in favor of larger ``k``. If ``assign`` is given, it labels each point
    with its cluster index, and one KDE is built per non-degenerate cluster.

    Parameters
    ----------
    pts : numpy.ndarray
        Points of shape ``(npts, ndim)``.
    max_k : int
        Largest number of clusters to try.
    trials : int
        Number of clustering attempts per candidate number of clusters.
    assign : numpy.ndarray, optional
        Integer cluster label of each point.
    jobs : int
        Passed to ``progress_map`` when clustering.
    """

    def __init__(self, pts, max_k=40, trials=5, assign=None, jobs=1):
        self.jobs = jobs
        if assign is None:
            log.info('clustering ...')
            # Make sure that each trial gets a different random number state.
            # We draw a random integer s in the main process, and the trial
            # with index i seeds itself with i + s.
            #
            # The seed must be an unsigned 32-bit integer. Trial indices run
            # over range(trials, (max_k + 1) * trials), so s is drawn from
            # [0, 2**32 - (max_k + 1) * trials) to guarantee i + s < 2**32.
            seed = np.random.randint(0, 2**32 - (max_k + 1) * trials)
            func = partial(_cluster, type(self), pts, trials, seed=seed,
                           jobs=jobs)
            self.bic, self.k, self.kdes = max(
                self._map(func, range(trials, (max_k + 1) * trials)),
                key=lambda items: items[:2])
        else:
            # Build KDEs for each cluster, skipping degenerate clusters
            self.kdes = []
            npts, ndim = pts.shape
            self.k = assign.max() + 1
            for i in range(self.k):
                sel = (assign == i)
                cluster_pts = pts[sel, :]
                # Equivalent to but faster than len(set(pts))
                nuniq = len(np.unique(cluster_pts, axis=0))
                # Skip if there are fewer unique points than dimensions
                if nuniq <= ndim:
                    continue
                try:
                    kde = gaussian_kde(cluster_pts.T)
                except (np.linalg.LinAlgError, ValueError):
                    # If there are fewer unique points than degrees of freedom,
                    # then the KDE will fail because the covariance matrix is
                    # singular. In that case, don't bother adding that cluster.
                    pass
                else:
                    self.kdes.append(kde)

            # Calculate BIC (here: log-likelihood minus penalty, so larger is
            # better). The number of parameters is:
            #
            # * ndim for each centroid location
            #
            # * (ndim+1)*ndim/2 Kernel covariances for each cluster
            #
            # * one weighting factor for the cluster (minus one for the
            #   overall constraint that the weights must sum to one)
            nparams = (self.k * ndim +
                       0.5 * self.k * (ndim + 1) * ndim + self.k - 1)
            # A zero density gives log(0) = -inf; let that through silently.
            with np.errstate(divide='ignore'):
                self.bic = (
                    np.sum(np.log(self.eval_kdes(pts))) -
                    0.5 * nparams * np.log(npts))

    def eval_kdes(self, pts):
        """Evaluate the weighted mixture of KDEs at ``pts`` of shape
        ``(npts, ndim)``."""
        pts = pts.T
        return sum(w * kde(pts) for w, kde in zip(self.weights, self.kdes))

    def __call__(self, pts):
        return self.eval_kdes(pts)

    @property
    def weights(self):
        """Get the cluster weights: the fraction of the points within each
        cluster.
        """
        w = np.asarray([kde.n for kde in self.kdes])
        return w / np.sum(w)

    def _map(self, func, items):
        return progress_map(func, items, jobs=self.jobs)


class SkyKDE(ClusteredKDE):
    """Clustered KDE of points on the sky.

    Subclasses override :meth:`transform` to map input points, whose first
    two columns are right ascension and declination in radians, into the
    coordinates in which clustering and density estimation are performed.
    This class adds HEALPix multi-order map output via :meth:`as_healpix`.
    """

    @classmethod
    def transform(cls, pts):
        """Override in sub-classes to transform points."""
        raise NotImplementedError

    def __init__(self, pts, max_k=40, trials=5, assign=None, jobs=1):
        # Points are only transformed when no cluster assignment is given;
        # with an explicit ``assign``, ``pts`` are used exactly as passed.
        if assign is None:
            pts = self.transform(pts)
        super().__init__(
            pts, max_k=max_k, trials=trials, assign=assign, jobs=jobs)

    def __call__(self, pts):
        return super().__call__(self.transform(pts))

    def _bayestar_adaptive_grid(self, top_nside=16, rounds=8):
        """Implementation of the BAYESTAR adaptive mesh refinement scheme as
        described in Section VI of Singer & Price 2016, PRD, 93, 024013
        :doi:`10.1103/PhysRevD.93.024013`.

        Parameters
        ----------
        top_nside : int, optional
            HEALPix resolution at which the whole sky is first evaluated.
        rounds : int, optional
            Number of resolution levels; ``rounds - 1`` refinement steps are
            performed.

        Returns
        -------
        cells : list
            List of ``(probdensity, nside, ipix)`` tuples, one per cell, with
            NESTED pixel indices. Always a list, even if ``rounds <= 1``.

        FIXME: Consider refactoring BAYESTAR itself to perform the adaptation
        step in Python.
        """
        top_npix = ah.nside_to_npix(top_nside)
        nrefine = top_npix // 4
        # Seed with every pixel one level coarser than top_nside (there are
        # exactly nrefine of them) at zero density, so that the first round
        # refines all of them and evaluates the full top_nside grid.
        cells = list(zip(
            [0] * nrefine, [top_nside // 2] * nrefine, range(nrefine)))
        for iround in range(rounds - 1):
            print('adaptive refinement round {} of {} ...'.format(
                  iround + 1, rounds - 1))
            # Rank by probability: density times pixel area, and pixel area
            # is proportional to 1 / nside**2. The most probable cells end up
            # at the end of the list.
            cells = sorted(cells, key=lambda p_n_i: p_n_i[0] / p_n_i[1]**2)
            # Split each of the nrefine most probable cells into its four
            # NESTED children at twice the resolution.
            new_nside, new_ipix = np.transpose([
                (nside * 2, ipix * 4 + i)
                for _, nside, ipix in cells[-nrefine:] for i in range(4)])
            theta, phi = hp.pix2ang(new_nside, new_ipix, nest=True)
            ra = phi
            dec = 0.5 * np.pi - theta
            p = self(np.column_stack((ra, dec)))
            # Replace the refined parents with their evaluated children.
            cells[-nrefine:] = zip(p, new_nside, new_ipix)
        return cells

    def as_healpix(self, top_nside=16):
        """Return a HEALPix multi-order map of the posterior density.

        Parameters
        ----------
        top_nside : int, optional
            Coarsest resolution of the adaptive grid.

        Returns
        -------
        `astropy.table.Table`
            Table with columns ``UNIQ`` (NUNIQ pixel index) and
            ``PROBDENSITY`` (normalized probability per steradian).
        """
        post, nside, ipix = zip(*self._bayestar_adaptive_grid(
            top_nside=top_nside))
        post = np.asarray(post)
        nside = np.asarray(nside)
        ipix = np.asarray(ipix)

        # Make sure that sky map is normalized (it should be already)
        post /= np.sum(post * ah.nside_to_pixel_area(nside).to_value(u.sr))

        # Convert from NESTED to UNIQ pixel indices
        order = np.log2(nside).astype(int)
        uniq = moc.nest2uniq(order.astype(np.int8), ipix)

        # Done!
        return Table([uniq, post], names=['UNIQ', 'PROBDENSITY'], copy=False)


# We have to put in some hooks to make instances of Clustered2DSkyKDE picklable
# because we dynamically create subclasses with different values of the 'frame'
# class variable. This gets even trickier because we need both the class and
# instance objects to be picklable.


class _Clustered2DSkyKDEMeta(type):  # noqa: N802
    """Metaclass to make dynamically created subclasses of Clustered2DSkyKDE
    picklable.

    It defines no behavior of its own. It only gives those classes a
    distinct metaclass, for which a reducer is registered with
    :func:`copyreg.pickle` (``_Clustered2DSkyKDEMeta_pickle``).
    """


def _Clustered2DSkyKDEMeta_pickle(cls):  # noqa: N802
    """Pickle dynamically created subclasses of Clustered2DSkyKDE.

    Returns a ``(callable, args)`` reduction. On unpickling, the class is
    rebuilt by calling ``type(name, bases, {'frame': frame})``.
    """
    namespace = {'frame': cls.frame}
    return type, (cls.__name__, cls.__bases__, namespace)


# Register function to pickle subclasses of Clustered2DSkyKDE.
# Any class whose metaclass is _Clustered2DSkyKDEMeta (Clustered2DSkyKDE and
# the subclasses it creates on the fly) is reduced by
# _Clustered2DSkyKDEMeta_pickle instead of being looked up by name, which
# would fail because the dynamic subclasses are not module attributes.
copyreg.pickle(_Clustered2DSkyKDEMeta, _Clustered2DSkyKDEMeta_pickle)


def _Clustered2DSkyKDE_factory(name, frame):  # noqa: N802
    """Unpickle instances of dynamically created subclasses of
    Clustered2DSkyKDE.

    Recreates a subclass called ``name`` whose ``frame`` class attribute is
    ``frame``, and allocates an uninitialized instance of it. Pickle then
    restores the instance ``__dict__`` from the state that
    ``Clustered2DSkyKDE.__reduce__`` returns.

    FIXME: In Python 3, we could make this a class method of Clustered2DSkyKDE.
    Unfortunately, Python 2 is picky about pickling bound class methods.
    """
    subclass = type(name, (Clustered2DSkyKDE,), {'frame': frame})
    # Skip Clustered2DSkyKDE.__new__, which would derive a new frame from
    # sample points, and allocate via the next class in the MRO instead.
    allocate = super(Clustered2DSkyKDE, Clustered2DSkyKDE).__new__
    return allocate(subclass)


class Clustered2DSkyKDE(SkyKDE, metaclass=_Clustered2DSkyKDEMeta):
    r"""Represents a kernel-density estimate of a sky-position PDF that has
    been decomposed into clusters, using a different kernel for each
    cluster.

    The estimated PDF is

    .. math::

      p\left( \vec{\theta} \right) = \sum_{i = 0}^{k-1} \frac{N_i}{N}
      \frac{1}{N_i} \sum_{\vec{x} \in C_i}
      N\left[\vec{x}, \Sigma_i\right]\left( \vec{\theta} \right)

    where :math:`C_i` is the set of points belonging to cluster
    :math:`i`, :math:`N_i` is the number of points in this cluster,
    and :math:`\Sigma_i` is the optimally-converging KDE covariance
    associated to cluster :math:`i`.

    The number of clusters, :math:`k`, is chosen to maximize the `BIC
    <http://en.wikipedia.org/wiki/Bayesian_information_criterion>`_
    for the given set of points being drawn from the clustered KDE.
    The points are assigned to clusters using the k-means algorithm,
    with a decorrelated metric.  The overall clustering behavior is
    similar to the well-known `X-Means
    <http://www.cs.cmu.edu/~dpelleg/download/xmeans.pdf>`_ algorithm.

    Every instance belongs to its own dynamically created subclass whose
    ``frame`` class attribute holds the :class:`EigenFrame` of the samples
    it was built from (see ``__new__``).
    """

    # Overridden on each dynamically created subclass with the EigenFrame
    # computed from the input samples.
    frame = None

    @classmethod
    def transform(cls, pts):
        # Rotate (ra, dec) into the eigenframe, then use (longitude,
        # sin(latitude)): equal areas on the sphere map to equal areas in
        # this (phi, z) plane.
        pts = SkyCoord(*pts.T, unit='rad').transform_to(cls.frame).spherical
        return np.column_stack((pts.lon.rad, np.sin(pts.lat.rad)))

    def __new__(cls, pts, *args, **kwargs):
        # Build a one-off subclass carrying the sample-specific frame, and
        # allocate the instance as that subclass.
        frame = EigenFrame.for_coords(SkyCoord(*pts.T, unit='rad'))
        name = '{:s}_{:x}'.format(cls.__name__, id(frame))
        new_cls = type(name, (cls,), {'frame': frame})
        return super().__new__(new_cls)

    def __reduce__(self):
        """Pickle instances of dynamically created subclasses of
        Clustered2DSkyKDE.

        Returns ``(factory, (class name, frame), instance __dict__)`` so the
        subclass can be recreated by ``_Clustered2DSkyKDE_factory``.
        """
        factory_args = self.__class__.__name__, self.frame
        return _Clustered2DSkyKDE_factory, factory_args, self.__dict__

    def eval_kdes(self, pts):
        # Longitude is periodic: also evaluate at phi +/- 2*pi so that kernel
        # mass which spills across the phi = 0 / 2*pi seam is wrapped back.
        base = super().eval_kdes
        dphis = (0.0, 2 * np.pi, -2 * np.pi)
        phi, z = pts.T
        return sum(base(np.column_stack((phi + dphi, z))) for dphi in dphis)


class Clustered3DSkyKDE(SkyKDE):
    """Like :class:`Clustered2DSkyKDE`, but clusters in 3D
    space.  Can compute volumetric posterior density (per cubic Mpc),
    and also produce Healpix maps of the mean and standard deviation
    of the log-distance.
    """

    @classmethod
    def transform(cls, pts):
        """Convert spherical sample coordinates to Cartesian ``(x, y, z)``."""
        coords = SkyCoord(*pts.T, unit='rad')
        return coords.cartesian.xyz.value.T

    def __call__(self, pts, distances=False):
        """Given an array of positions in RA, Dec, compute the marginal sky
        posterior and optionally the conditional distance parameters.

        Returns
        -------
        probdensity : tuple
            Marginal sky density at each point.
        mu, sigma, norm
            Only if ``distances`` is true: the conditional distance
            distribution parameters from
            :func:`distance.moments_to_parameters`.
        """
        clusters = self.kdes
        moments = partial(
            distance.cartesian_kde_to_moments,
            datasets=[k.dataset for k in clusters],
            inverse_covariances=[k.inv_cov for k in clusters],
            weights=self.weights)
        results = self._map(moments, self.transform(pts))
        probdensity, mean, std = zip(*results)
        if not distances:
            return probdensity
        mu, sigma, norm = distance.moments_to_parameters(mean, std)
        return probdensity, mu, sigma, norm

    def posterior_spherical(self, pts):
        """Evaluate the posterior probability density in spherical polar
        coordinates, as a function of (ra, dec, distance).
        """
        plain_call = super().__call__
        return plain_call(pts)

    def as_healpix(self, top_nside=16):
        """Return a HEALPix multi-order map of the posterior density and
        conditional distance distribution parameters.

        The sky density comes from :meth:`SkyKDE.as_healpix`. The columns
        ``DISTMU``, ``DISTSIGMA`` and ``DISTNORM`` are then filled in by
        evaluating this estimator at each pixel center.
        """
        skymap = super().as_healpix(top_nside=top_nside)
        order, ipix = moc.uniq2nest(skymap['UNIQ'])
        nside = 2 ** order.astype(int)
        theta, phi = hp.pix2ang(nside, ipix, nest=True)
        radec = np.column_stack((phi, 0.5 * np.pi - theta))
        print('evaluating distance layers ...')
        _, distmu, distsigma, distnorm = self(radec, distances=True)
        skymap['DISTMU'] = distmu
        skymap['DISTSIGMA'] = distsigma
        skymap['DISTNORM'] = distnorm
        return skymap


class Clustered2Plus1DSkyKDE(Clustered3DSkyKDE):
    """A hybrid sky map estimator that uses a 2D clustered KDE for the marginal
    distribution as a function of (RA, Dec) and a 3D clustered KDE for the
    conditional distance distribution.
    """

    def __init__(self, pts, max_k=40, trials=5, assign=None, jobs=1):
        options = dict(max_k=max_k, trials=trials, assign=assign, jobs=jobs)
        # The 2D marginal estimator is built only when clustering from
        # scratch, i.e. when no precomputed ``assign`` is supplied.
        if assign is None:
            self.twod = Clustered2DSkyKDE(pts, **options)
        super().__init__(pts, **options)

    def __call__(self, pts, distances=False):
        """Return the 2D marginal sky density at ``pts``. If ``distances`` is
        true, also return the conditional distance parameters from the 3D
        estimator.
        """
        probdensity = self.twod(pts)
        if not distances:
            return probdensity
        _, distmu, distsigma, distnorm = super().__call__(
            pts, distances=True)
        return probdensity, distmu, distsigma, distnorm

    def posterior_spherical(self, pts):
        """Evaluate the posterior probability density in spherical polar
        coordinates, as a function of (ra, dec, distance).

        The 3D joint density is rescaled so that its sky marginal matches the
        2D estimate: ``p2d(ra, dec) * p3d(ra, dec, d) / p3d(ra, dec)``.
        """
        marginal_2d = self(pts)
        joint_3d = super().posterior_spherical(pts)
        marginal_3d = super().__call__(pts)
        return marginal_2d * joint_3d / marginal_3d

ligo/skymap/moc.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
#
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Support for HEALPix UNIQ pixel indexing [1]_ and multi-order coverage (MOC)
maps [2]_.

References
----------
.. [1] Reinecke & Hivon, 2015. "Efficient data structures for masks on 2D
       grids." AA 580, A132. :doi:`10.1051/0004-6361/201526549`
.. [2] Boch et al., 2014. "MOC - HEALPix Multi-Order Coverage map." IVOA
       Recommendation <http://ivoa.net/documents/MOC/>.

"""

from astropy import table
import numpy as np
from numpy.lib.recfunctions import repack_fields

from .core import nest2uniq, uniq2nest, uniq2order, uniq2pixarea, uniq2ang
from .core import rasterize as _rasterize
from .util.numpy import add_newdoc_ufunc, require_contiguous_aligned

# Public API: the (wrapped) NUNIQ conversion ufuncs and rasterize().
__all__ = ('nest2uniq', 'uniq2nest', 'uniq2order', 'uniq2pixarea',
           'uniq2ang', 'rasterize')


# Attach a NumPy-style docstring to the compiled ufunc, then rebind the public
# name to a require_contiguous_aligned wrapper (presumably this makes array
# inputs contiguous and aligned before the C call; see ..util.numpy).
add_newdoc_ufunc(nest2uniq, """\
Convert a pixel index from NESTED to NUNIQ ordering.

Parameters
----------
order : `numpy.ndarray`
    HEALPix resolution order, the logarithm base 2 of `nside`
ipix : `numpy.ndarray`
    NESTED pixel index

Returns
-------
uniq : `numpy.ndarray`
    NUNIQ pixel index

""")
nest2uniq = require_contiguous_aligned(nest2uniq)


# Attach a NumPy-style docstring to the compiled ufunc, then rebind the public
# name to its require_contiguous_aligned wrapper.
add_newdoc_ufunc(uniq2order, """\
Determine the HEALPix resolution order of a HEALPix NUNIQ index.

Parameters
----------
uniq : `numpy.ndarray`
    NUNIQ pixel index

Returns
-------
order : `numpy.ndarray`
    HEALPix resolution order, the logarithm base 2 of `nside`

""")
uniq2order = require_contiguous_aligned(uniq2order)


# Attach a NumPy-style docstring to the compiled ufunc, then rebind the public
# name to its require_contiguous_aligned wrapper.
add_newdoc_ufunc(uniq2pixarea, """\
Determine the area of the pixel identified by a HEALPix NUNIQ index.

Parameters
----------
uniq : `numpy.ndarray`
    NUNIQ pixel index

Returns
-------
area : `numpy.ndarray`
    The pixel's area in steradians

""")
uniq2pixarea = require_contiguous_aligned(uniq2pixarea)


# Attach a NumPy-style docstring to the compiled ufunc, then rebind the public
# name to its require_contiguous_aligned wrapper. This ufunc has two outputs.
add_newdoc_ufunc(uniq2nest, """\
Convert a pixel index from NUNIQ to NESTED ordering.

Parameters
----------
uniq : `numpy.ndarray`
    NUNIQ pixel index

Returns
-------
order : `numpy.ndarray`
    HEALPix resolution order (logarithm base 2 of `nside`)
ipix : `numpy.ndarray`
    NESTED pixel index

""")
uniq2nest = require_contiguous_aligned(uniq2nest)


def rasterize(moc_data, order=None):
    """Convert a multi-order HEALPix dataset to fixed-order NESTED ordering.

    Parameters
    ----------
    moc_data : `numpy.ndarray`
        A multi-order HEALPix dataset stored as a Numpy record array whose
        first column is called UNIQ and contains the NUNIQ pixel index. Every
        point on the unit sphere must be contained in exactly one pixel in the
        dataset.
    order : int, optional
        The desired output resolution order, or :obj:`None` for the maximum
        resolution present in the dataset.

    Returns
    -------
    nested_data : `numpy.ndarray`
        A fixed-order, NESTED-ordering HEALPix dataset with all of the columns
        that were in moc_data, with the exception of the UNIQ column.

    Notes
    -----
    Negative values of ``order`` are treated like :obj:`None`. When ``order``
    is coarser than some pixels in ``moc_data``, those pixels are first
    merged into their parents at ``order``. Every non-UNIQ column of the
    merged pixels becomes the area-weighted average of its children.

    """
    if order is None or order < 0:
        # Sentinel passed to the core rasterizer for "use the finest order
        # present in the dataset".
        order = -1
    else:
        orig_order, orig_nest = uniq2nest(moc_data['UNIQ'])
        # Pixels finer than the requested order must be merged into parents.
        to_downsample = order < orig_order
        if np.any(to_downsample):
            # Boolean indexing returns copies, so the in-place scaling below
            # does not modify the caller's moc_data.
            to_keep = table.Table(moc_data[~to_downsample], copy=False)
            orig_order = orig_order[to_downsample]
            orig_nest = orig_nest[to_downsample]
            to_downsample = table.Table(moc_data[to_downsample], copy=False)

            # Each parent at `order` contains 4**(orig_order - order) children
            # at the child's order; weight each child by its share of the
            # parent's area.
            ratio = 1 << (2 * np.int64(orig_order - order))
            weights = 1.0 / ratio
            # NOTE(review): in-place `*=` with float weights assumes every
            # non-UNIQ column has a floating-point dtype.
            for colname, column in to_downsample.columns.items():
                if colname != 'UNIQ':
                    column *= weights
            # In NESTED ordering, integer division by the child count gives
            # the parent's index at `order`.
            to_downsample['UNIQ'] = nest2uniq(order, orig_nest // ratio)
            # Summing the area-weighted children yields each parent's average.
            to_downsample = to_downsample.group_by(
                'UNIQ').groups.aggregate(np.sum)

            moc_data = table.vstack((to_keep, to_downsample))

    # Ensure that moc_data has appropriate padding for each of its columns to
    # be properly aligned in order to avoid undefined behavior.
    moc_data = repack_fields(np.asarray(moc_data), align=True)

    return _rasterize(moc_data, order=order)


# The helpers were needed only at import time to document and wrap the ufuncs
# above; remove them from the module namespace.
del add_newdoc_ufunc, require_contiguous_aligned

ligo/skymap/bayestar/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Rapid sky localization with BAYESTAR [1]_.

References
----------
.. [1] Singer & Price, 2016. "Rapid Bayesian position reconstruction for
       gravitational-wave transients." PRD, 93, 024013.
       :doi:`10.1103/PhysRevD.93.024013`

"""
import inspect
import logging
import os
import sys

from astropy.table import Column, Table
from astropy import units as u
import lal
import lalsimulation
import numpy as np

from .. import distance
from . import filter
from ..io.hdf5 import write_samples
from ..io.fits import metadata_for_version_module
from ..io.events.base import Event
from . import filter  # noqa
from ..kde import Clustered2Plus1DSkyKDE
from .. import moc
from .. import healpix_tree
from .. import version
from .. import core
from ..core import (antenna_factor, signal_amplitude_model,
                    log_posterior_toa_phoa_snr as _log_posterior_toa_phoa_snr)
from ..util.numpy import require_contiguous_aligned
from ..util.stopwatch import Stopwatch
from .ez_emcee import ez_emcee

# Public API of the BAYESTAR package.
__all__ = ('derasterize', 'localize', 'rasterize', 'antenna_factor',
           'signal_amplitude_model')

# Logger shared by the BAYESTAR routines in this module.
log = logging.getLogger('BAYESTAR')

# Rebind the compiled core ufuncs to require_contiguous_aligned wrappers
# (presumably making array arguments contiguous and aligned before they reach
# the C code; NOTE(review): confirm against ligo.skymap.util.numpy).
antenna_factor = require_contiguous_aligned(antenna_factor)
signal_amplitude_model = require_contiguous_aligned(signal_amplitude_model)
_log_posterior_toa_phoa_snr = require_contiguous_aligned(
    _log_posterior_toa_phoa_snr)


# Wrap so that ufunc parameter names are known
def log_posterior_toa_phoa_snr(
        ra, sin_dec, distance, u, twopsi, t, min_distance, max_distance,
        prior_distance_power, cosmology, gmst, sample_rate, epochs, snrs,
        responses, locations, horizons):
    """Evaluate the BAYESTAR time/phase/amplitude log posterior.

    Thin pass-through to the wrapped compiled ufunc. It exists as a regular
    Python function only so that its parameters have names, which lets
    callers such as :func:`log_post` supply arguments by keyword. Every
    argument is forwarded positionally, unchanged.
    """
    return _log_posterior_toa_phoa_snr(
        ra, sin_dec, distance, u, twopsi, t, min_distance, max_distance,
        prior_distance_power, cosmology, gmst, sample_rate, epochs, snrs,
        responses, locations, horizons)


# Wrap so that fixed parameter values are pulled from keyword arguments
def log_post(params, *args, **kwargs):
    """Evaluate the log posterior with some parameters optionally held fixed.

    Parameters
    ----------
    params : `numpy.ndarray`
        Sampled values. The rows of ``params.T`` are consumed in the order
        ``ra, sin_dec, distance, u, twopsi, t``, skipping any name that is
        given as a keyword argument.
    *args
        Positional arguments appended after those six parameters.
    **kwargs
        Fixed values for any of the six parameters (by name). All other
        keyword arguments are forwarded to
        :func:`log_posterior_toa_phoa_snr`.
    """
    free_values = list(params.T)
    fixed = dict(kwargs)
    ordered = []
    for name in ('ra', 'sin_dec', 'distance', 'u', 'twopsi', 't'):
        if name in fixed:
            ordered.append(fixed.pop(name))
        else:
            ordered.append(free_values.pop(0))
    return log_posterior_toa_phoa_snr(*ordered, *args, **fixed)


def localize_emcee(args, xmin, xmax, chain_dump=None):
    """Localize by MCMC sampling followed by a clustered KDE.

    Parameters
    ----------
    args : tuple
        Passed as ``args`` to :func:`ez_emcee` (presumably forwarded to
        :func:`log_post`).
    xmin, xmax : array-like
        Lower and upper bounds of the sampled parameters, passed to
        :func:`ez_emcee`.
    chain_dump : str, optional
        If set, write the posterior samples to this file with
        :func:`write_samples`.

    Returns
    -------
    `astropy.table.Table`
        Multi-order sky map produced by
        :meth:`Clustered2Plus1DSkyKDE.as_healpix`.
    """
    samples = ez_emcee(log_post, xmin, xmax, args=args, vectorize=True)

    # Undo the sampling reparameterization: column 1 holds sin(dec) and
    # column 3 holds cos(inclination).
    for col, inverse in ((1, np.arcsin), (3, np.arccos)):
        samples[:, col] = inverse(samples[:, col])

    if chain_dump:
        _, ncols = samples.shape
        column_names = ['ra', 'dec', 'distance', 'inclination', 'twopsi',
                        'time'][:ncols]
        write_samples(Table(rows=samples, names=column_names, copy=False),
                      chain_dump, path='/bayestar/posterior_samples',
                      overwrite=True)

    # Keep the KDE fast: feed it at most 1000 randomly chosen samples of
    # (ra, dec, distance).
    subset = np.random.permutation(samples)[:1000, :3]
    return Clustered2Plus1DSkyKDE(subset).as_healpix()


def condition(
        event, waveform='o2-uberbank', f_low=30.0,
        enable_snr_series=True, f_high_truncate=0.95):
    """Prepare detector, template, and SNR time series data for an event.

    Parameters
    ----------
    event : `ligo.skymap.io.events.Event`
        The event candidate.
    waveform : str, optional
        The name of the waveform approximant.
    f_low : float, optional
        The low frequency cutoff.
    enable_snr_series : bool, optional
        Set to False to ignore SNR time series. Single-detector triggers
        without an SNR are then dropped.
    f_high_truncate : float, optional
        Truncate the noise power spectral densities at this factor times the
        highest sampled frequency.

    Returns
    -------
    epoch : `lal.LIGOTimeGPS`
        Reference time: the weighted mean time of arrival.
    sample_rate : float
        Sample rate of the SNR time series (``1 / deltaT``).
    toas : `numpy.ndarray`
        Time of the first sample of each SNR series relative to ``epoch``,
        in seconds.
    snrs : `numpy.ndarray`
        SNR time series as (amplitude, unwrapped phase) pairs, with shape
        ``(ndetectors, nsamples, 2)``.
    responses : `numpy.ndarray`
        Detector response tensors.
    locations : `numpy.ndarray`
        Detector positions divided by the speed of light, relative to the
        weighted barycenter of the array.
    horizons : `numpy.ndarray`
        SNR=1 horizon distances for each detector.

    Raises
    ------
    ValueError
        If no detectors remain after filtering, or if the SNR time series
        have mixed sample rates, even lengths, or mixed lengths, or are not
        centered on the trigger times.
    """
    singles = event.singles
    if not enable_snr_series:
        singles = [single for single in singles if single.snr is not None]

    # Check after filtering: dropping triggers that lack an SNR can leave
    # nothing to localize, which would otherwise fail much later with an
    # obscure NumPy axis error.
    if len(singles) == 0:
        raise ValueError('Cannot localize an event with zero detectors.')

    ifos = [single.detector for single in singles]

    # Extract SNRs from table.
    snrs = np.ma.asarray([
        np.ma.masked if single.snr is None else single.snr
        for single in singles])

    # Look up physical parameters for detector.
    detectors = [lalsimulation.DetectorPrefixToLALDetector(str(ifo))
                 for ifo in ifos]
    responses = np.asarray([det.response for det in detectors])
    locations = np.asarray([det.location for det in detectors]) / lal.C_SI

    # Power spectra for each detector.
    psds = [single.psd for single in singles]
    psds = [filter.InterpolatedPSD(filter.abscissa(psd), psd.data.data,
                                   f_high_truncate=f_high_truncate)
            for psd in psds]

    log.debug('calculating templates')
    H = filter.sngl_inspiral_psd(waveform, f_min=f_low, **event.template_args)

    log.debug('calculating noise PSDs')
    HS = [filter.signal_psd_series(H, S) for S in psds]

    # Signal models for each detector.
    log.debug('calculating Fisher matrix elements')
    signal_models = [filter.SignalModel(_) for _ in HS]

    # Get SNR=1 horizon distances for each detector.
    horizons = np.asarray([signal_model.get_horizon_distance()
                           for signal_model in signal_models])

    weights = np.ma.asarray([
        1 / np.square(signal_model.get_crb_toa_uncert(snr))
        for signal_model, snr in zip(signal_models, snrs)])

    # Center detector array.
    locations -= (np.sum(locations * weights.reshape(-1, 1), axis=0) /
                  np.sum(weights))

    if enable_snr_series:
        snr_series = [single.snr_series for single in singles]
        if all(s is None for s in snr_series):
            snr_series = None
    else:
        snr_series = None

    # Maximum barycentered arrival time error:
    # |distance from array barycenter to furthest detector| / c + 5 ms.
    # For LHO+LLO, this is 15.0 ms.
    # For an arbitrary terrestrial detector network, the maximum is 26.3 ms.
    max_abs_t = np.max(
        np.sqrt(np.sum(np.square(locations), axis=1))) + 0.005

    # Calculate a minimum arrival time prior for low-bandwidth signals.
    # This is important for early-warning events, for which the light travel
    # time between detectors may be only a tiny fraction of the template
    # autocorrelation time.
    #
    # The period of the autocorrelation function is 2 pi times the timing
    # uncertainty at SNR=1 (see Eq. (24) of [1]_). We want a quarter period:
    # there is a factor of a half because we want the first zero crossing,
    # and another factor of a half because the SNR time series will be cropped
    # to a two-sided interval.
    max_acor_t = max(0.5 * np.pi * signal_model.get_crb_toa_uncert(1)
                     for signal_model in signal_models)
    max_abs_t = max(max_abs_t, max_acor_t)

    if snr_series is None:
        log.warning("No SNR time series found, so we are creating a "
                    "zero-noise SNR time series from the whitened template's "
                    "autocorrelation sequence. The sky localization "
                    "uncertainty may be underestimated.")

        acors, sample_rates = zip(
            *[filter.autocorrelation(_, max_abs_t) for _ in HS])
        sample_rate = sample_rates[0]
        deltaT = 1 / sample_rate
        nsamples = len(acors[0])
        assert all(sample_rate == _ for _ in sample_rates)
        assert all(nsamples == len(_) for _ in acors)
        nsamples = nsamples * 2 - 1

        snr_series = []
        for acor, single in zip(acors, singles):
            series = lal.CreateCOMPLEX8TimeSeries(
                'fake SNR', 0, 0, deltaT, lal.StrainUnit, nsamples)
            series.epoch = single.time - 0.5 * (nsamples - 1) * deltaT
            # Mirror the one-sided autocorrelation into a two-sided sequence.
            acor = np.concatenate((np.conj(acor[:0:-1]), acor))
            series.data.data = single.snr * filter.exp_i(single.phase) * acor
            snr_series.append(series)

    # Ensure that all of the SNR time series have the same sample rate.
    # FIXME: for now, the Python wrapper expects all of the SNR time series to
    # also be the same length.
    deltaT = snr_series[0].deltaT
    sample_rate = 1 / deltaT
    if any(deltaT != series.deltaT for series in snr_series):
        raise ValueError('BAYESTAR does not yet support SNR time series with '
                         'mixed sample rates')

    # Ensure that all of the SNR time series have odd lengths.
    if any(len(series.data.data) % 2 == 0 for series in snr_series):
        raise ValueError('SNR time series must have odd lengths')

    # Trim time series to the desired length.
    max_abs_n = int(np.ceil(max_abs_t * sample_rate))
    desired_length = 2 * max_abs_n - 1
    for i, series in enumerate(snr_series):
        length = len(series.data.data)
        if length > desired_length:
            snr_series[i] = lal.CutCOMPLEX8TimeSeries(
                series, length // 2 + 1 - max_abs_n, desired_length)

    # FIXME: for now, the Python wrapper expects all of the SNR time series to
    # also be the same length.
    nsamples = len(snr_series[0].data.data)
    if any(nsamples != len(series.data.data) for series in snr_series):
        raise ValueError('BAYESTAR does not yet support SNR time series of '
                         'mixed lengths')

    # Perform sanity checks that the middle sample of the SNR time series match
    # the sngl_inspiral records to the nearest sample (plus the smallest
    # representable LIGOTimeGPS difference of 1 nanosecond).
    for ifo, single, series in zip(ifos, singles, snr_series):
        shift = np.abs(0.5 * (nsamples - 1) * series.deltaT +
                       float(series.epoch - single.time))
        if shift >= deltaT + 1e-8:
            raise ValueError('BAYESTAR expects the SNR time series to be '
                             'centered on the single-detector trigger times, '
                             'but {} was off by {} s'.format(ifo, shift))

    # Extract the TOAs in GPS nanoseconds from the SNR time series, assuming
    # that the trigger happened in the middle.
    toas_ns = [series.epoch.ns() + 1e9 * 0.5 * (len(series.data.data) - 1) *
               series.deltaT for series in snr_series]

    # Collect all of the SNR series in one array.
    snr_series = np.vstack([series.data.data for series in snr_series])

    # Center times of arrival and compute GMST at mean arrival time.
    # Pre-center in integer nanoseconds to preserve precision of
    # initial datatype.
    epoch = sum(toas_ns) // len(toas_ns)
    toas = 1e-9 * (np.asarray(toas_ns) - epoch)
    mean_toa = np.average(toas, weights=weights)
    toas -= mean_toa
    epoch += int(np.round(1e9 * mean_toa))
    epoch = lal.LIGOTimeGPS(0, int(epoch))

    # Translate SNR time series back to time of first sample.
    toas -= 0.5 * (nsamples - 1) * deltaT

    # Convert complex SNRs to amplitude and phase
    snrs_abs = np.abs(snr_series)
    snrs_arg = filter.unwrap(np.angle(snr_series))
    snrs = np.stack((snrs_abs, snrs_arg), axis=-1)

    return epoch, sample_rate, toas, snrs, responses, locations, horizons


def condition_prior(horizons, min_distance=None, max_distance=None,
                    prior_distance_power=None, cosmology=False):
    """Fill in defaults for, and validate, the distance prior settings.

    Parameters
    ----------
    horizons : sequence of float
        SNR=1 horizon distances of the detectors.
    min_distance : float, optional
        Lower distance limit (default: 0).
    max_distance : float, optional
        Upper distance limit (default: ``max(horizons) / 4``, the SNR=4
        horizon of the most sensitive detector).
    prior_distance_power : int, optional
        Power ``k`` of the ``r**k`` distance prior (default: 2, uniform in
        volume).
    cosmology : bool, optional
        Whether the cosmological prior is enabled; a warning is logged if so.

    Returns
    -------
    tuple
        ``(min_distance, max_distance, prior_distance_power, cosmology)``

    Raises
    ------
    ValueError
        If ``min_distance`` is 0 while ``prior_distance_power`` is negative.
    """
    if cosmology:
        log.warning('Enabling cosmological prior. This feature is UNREVIEWED.')

    min_distance = 0 if min_distance is None else min_distance
    max_distance = max(horizons) / 4 if max_distance is None else max_distance
    prior_distance_power = (
        2 if prior_distance_power is None else prior_distance_power)

    # r**k with k < 0 diverges at r = 0, so it cannot start at 0 Mpc.
    diverges_at_origin = min_distance == 0 and prior_distance_power < 0
    if diverges_at_origin:
        raise ValueError(
            'Prior is a power law r^k with k={}, undefined at '
            'min_distance=0'.format(prior_distance_power))

    return min_distance, max_distance, prior_distance_power, cosmology


def localize(
        event, waveform='o2-uberbank', f_low=30.0,
        min_inclination=0, max_inclination=np.pi / 2,
        min_distance=None, max_distance=None, prior_distance_power=None,
        cosmology=False, mcmc=False, chain_dump=None,
        enable_snr_series=True, f_high_truncate=0.95):
    """Localize a compact binary signal using the BAYESTAR algorithm.

    Parameters
    ----------
    event : `ligo.skymap.io.events.Event`
        The event candidate.
    waveform : str, optional
        The name of the waveform approximant.
    f_low : float, optional
        The low frequency cutoff.
    min_inclination, max_inclination : float, optional
        The limits of integration over inclination angle, in radians
        (default: 0 to pi/2). Only used by the default (non-MCMC)
        integrator; if `mcmc` is set and these differ from the defaults, a
        warning is logged and they are ignored.
    min_distance, max_distance : float, optional
        The limits of integration over luminosity distance, in Mpc
        (default: determine automatically from detector sensitivity).
    prior_distance_power : int, optional
        The power of distance that appears in the prior
        (default: 2, uniform in volume).
    cosmology : bool, optional
        Set to enable a uniform in comoving volume prior (default: false).
    mcmc : bool, optional
        Set to use MCMC sampling rather than more accurate Gaussian quadrature.
    chain_dump : str, optional
        Save posterior samples to this filename if `mcmc` is set.
    enable_snr_series : bool, optional
        Set to False to disable SNR time series.
    f_high_truncate : float, optional
        Truncate the noise power spectral densities at this factor times the
        highest sampled frequency to suppress artifacts caused by incorrect
        PSD conditioning by some matched filter pipelines.

    Returns
    -------
    skymap : `astropy.table.Table`
        A 3D sky map in multi-order HEALPix format. Its metadata also holds
        the marginal distance mean and standard deviation (``distmean``,
        ``diststd``), the run time, and a provenance history.

    """
    # Hide event parameters, but show all other arguments
    def formatvalue(value):
        if isinstance(value, Event):
            return '=...'
        else:
            return '=' + repr(value)

    # Record this call's arguments for the provenance history. The frame
    # object references itself through this function's locals, so delete
    # it explicitly (as the `inspect` documentation recommends) instead of
    # keeping every large local array alive until the cyclic GC runs.
    frame = inspect.currentframe()
    try:
        argstr = inspect.formatargvalues(*inspect.getargvalues(frame),
                                         formatvalue=formatvalue)
        funcname = frame.f_code.co_name
    finally:
        del frame

    stopwatch = Stopwatch()
    stopwatch.start()

    epoch, sample_rate, toas, snrs, responses, locations, horizons = \
        condition(event, waveform=waveform, f_low=f_low,
                  enable_snr_series=enable_snr_series,
                  f_high_truncate=f_high_truncate)

    min_distance, max_distance, prior_distance_power, cosmology = \
        condition_prior(horizons, min_distance, max_distance,
                        prior_distance_power, cosmology)

    gmst = lal.GreenwichMeanSiderealTime(epoch)

    # Time and run sky localization.
    log.debug('starting computationally-intensive section')
    if mcmc:
        # Axis 1 of the stacked SNR array indexes time samples.
        max_abs_t = 2 * snrs.shape[1] / sample_rate
        if min_inclination != 0 or max_inclination != np.pi / 2:
            log.warning('inclination limits are not supported for MCMC mode')
        args = (min_distance, max_distance, prior_distance_power, cosmology,
                gmst, sample_rate, toas, snrs, responses, locations, horizons)
        skymap = localize_emcee(
            args=args,
            xmin=[0, -1, min_distance, -1, 0, 0],
            xmax=[2 * np.pi, 1, max_distance, 1, 2 * np.pi, 2 * max_abs_t],
            chain_dump=chain_dump)
    else:
        args = (min_inclination, max_inclination, min_distance, max_distance,
                prior_distance_power, cosmology, gmst, sample_rate, toas, snrs,
                responses, locations, horizons)
        skymap, log_bci, log_bsn = core.toa_phoa_snr(*args)
        skymap = Table(skymap, copy=False)
        skymap.meta['log_bci'] = log_bci
        skymap.meta['log_bsn'] = log_bsn

    # Convert distance moments to parameters (or, if the sky map already
    # has parameters, derive the moments from them).
    try:
        distmean = skymap.columns.pop('DISTMEAN')
        diststd = skymap.columns.pop('DISTSTD')
    except KeyError:
        distmean, diststd, _ = distance.parameters_to_moments(
            skymap['DISTMU'], skymap['DISTSIGMA'])
    else:
        skymap['DISTMU'], skymap['DISTSIGMA'], skymap['DISTNORM'] = \
            distance.moments_to_parameters(distmean, diststd)

    # Add marginal distance moments, weighting each pixel by its probability
    # (pixel area times probability density) and skipping non-finite pixels.
    good = np.isfinite(distmean) & np.isfinite(diststd)
    prob = (moc.uniq2pixarea(skymap['UNIQ']) * skymap['PROBDENSITY'])[good]
    distmean = distmean[good]
    diststd = diststd[good]
    rbar = (prob * distmean).sum()
    r2bar = (prob * (np.square(diststd) + np.square(distmean))).sum()
    skymap.meta['distmean'] = rbar
    skymap.meta['diststd'] = np.sqrt(r2bar - np.square(rbar))

    stopwatch.stop()
    end_time = lal.GPSTimeNow()
    log.info('finished computationally-intensive section in %s', stopwatch)

    # Fill in metadata and return.
    program, _ = os.path.splitext(os.path.basename(sys.argv[0]))
    skymap.meta.update(metadata_for_version_module(version))
    skymap.meta['creator'] = 'BAYESTAR'
    skymap.meta['origin'] = 'LIGO/Virgo'
    skymap.meta['gps_time'] = float(epoch)
    skymap.meta['runtime'] = stopwatch.real
    skymap.meta['instruments'] = {single.detector for single in event.singles}
    skymap.meta['gps_creation_time'] = end_time
    skymap.meta['history'] = [
        '',
        'Generated by calling the following Python function:',
        '{}.{}{}'.format(__name__, funcname, argstr),
        '',
        'This was the command line that started the program:',
        ' '.join([program] + sys.argv[1:])]

    return skymap


def rasterize(skymap, order=None):
    """Rasterize a multi-order sky map with `moc.rasterize` and convert
    probability density to probability per pixel.

    Parameters
    ----------
    skymap : `astropy.table.Table`
        Multi-order sky map with ``UNIQ`` and ``PROBDENSITY`` columns and,
        optionally, ``DISTMU``/``DISTSIGMA`` distance columns.
    order : int, optional
        Target HEALPix order, passed to `moc.rasterize`. If it is given and
        is coarser than the finest order present in ``skymap``, and the map
        has distance columns, the distance columns are averaged in moment
        form so that the coarser pixels carry correctly marginalized
        distance parameters.

    Returns
    -------
    `astropy.table.Table`
        The rasterized table, with ``PROBDENSITY`` renamed to ``PROB`` and
        multiplied by ``4 * pi / len(table)`` (the pixel area, assuming the
        rows form an all-sky equal-area grid), with unit pixel**-1.
    """
    # Order of the largest UNIQ value. NUNIQ values increase with order, so
    # this is the finest order present in the map.
    orig_order, _ = moc.uniq2nest(skymap['UNIQ'].max())

    # Determine whether we need to do nontrivial downsampling.
    downsampling = (order is not None and 0 <= order < orig_order
                    and 'DISTMU' in skymap.dtype.fields.keys())

    # If we are downsampling, then convert from distance parameters to
    # distance moments times probability density so that the averaging
    # that is automatically done by moc.rasterize() correctly marginalizes
    # the moments.
    if downsampling:
        # Work on a copy so that popping/adding columns below does not
        # modify the caller's table.
        skymap = Table(skymap, copy=True, meta=skymap.meta)

        probdensity = skymap['PROBDENSITY']
        distmu = skymap.columns.pop('DISTMU')
        distsigma = skymap.columns.pop('DISTSIGMA')

        bad = ~(np.isfinite(distmu) & np.isfinite(distsigma))
        distmean, diststd, _ = distance.parameters_to_moments(
            distmu, distsigma)
        distmean[bad] = np.nan
        diststd[bad] = np.nan
        skymap['DISTMEAN'] = probdensity * distmean
        # Despite its name, DISTVAR holds the density-weighted second raw
        # moment, std**2 + mean**2, which (unlike the variance) averages
        # linearly across pixels.
        skymap['DISTVAR'] = probdensity * (
            np.square(diststd) + np.square(distmean))

    skymap = Table(moc.rasterize(skymap, order=order),
                   meta=skymap.meta, copy=False)

    # If we are downsampling, then convert back to distance parameters.
    # NOTE(review): pixels whose averaged PROBDENSITY is 0 produce 0/0 here
    # (NaN, with a NumPy warning); confirm moments_to_parameters tolerates it.
    if downsampling:
        distmean = skymap.columns.pop('DISTMEAN') / skymap['PROBDENSITY']
        diststd = np.sqrt(
            skymap.columns.pop('DISTVAR') / skymap['PROBDENSITY']
            - np.square(distmean))
        skymap['DISTMU'], skymap['DISTSIGMA'], skymap['DISTNORM'] = \
            distance.moments_to_parameters(
                distmean, diststd)

    # Convert probability per steradian to probability per pixel.
    skymap.rename_column('PROBDENSITY', 'PROB')
    skymap['PROB'] *= 4 * np.pi / len(skymap)
    skymap['PROB'].unit = u.pixel ** -1
    return skymap


def derasterize(skymap):
    """Convert a rasterized sky map back to multi-order (UNIQ-indexed) form.

    The ``PROB`` column (probability per pixel) is converted back to
    ``PROBDENSITY`` (per steradian), and the pixels are regrouped with
    `healpix_tree.reconstruct_nested`.

    .. note:: The input table is modified in place: its ``PROB`` column is
       renamed to ``PROBDENSITY`` and rescaled before the new table is built.
       NOTE(review): callers that reuse the input afterwards see these
       changes, and calling this function twice on the same table fails
       because ``PROB`` no longer exists.

    Parameters
    ----------
    skymap : `astropy.table.Table`
        Rasterized sky map with a ``PROB`` column.

    Returns
    -------
    `astropy.table.Table`
        New table with a leading ``UNIQ`` column, sorted by ``UNIQ``. Its
        other columns come from `healpix_tree.reconstruct_nested` and take
        the units of the input columns by position.
    """
    skymap.rename_column('PROB', 'PROBDENSITY')
    # Undo the per-pixel scaling: probability per pixel -> per steradian,
    # assuming len(skymap) equal-area pixels covering the whole sphere.
    skymap['PROBDENSITY'] *= len(skymap) / (4 * np.pi)
    skymap['PROBDENSITY'].unit = u.steradian ** -1
    # reconstruct_nested yields 6-tuples (presumably one per reconstructed
    # pixel); only the NSIDE, pixel index, and value fields are used here.
    nside, _, ipix, _, _, value = zip(
        *healpix_tree.reconstruct_nested(skymap))
    nside = np.asarray(nside)
    ipix = np.asarray(ipix)
    value = np.stack(value)
    # NUNIQ packing: uniq = 4 * nside**2 + ipix.
    uniq = (4 * np.square(nside) + ipix)
    old_units = [column.unit for column in skymap.columns.values()]
    skymap = Table(value, meta=skymap.meta, copy=False)
    # NOTE(review): assumes the fields of `value` are in the same order as
    # the input table's columns -- confirm against healpix_tree.
    for old_unit, column in zip(old_units, skymap.columns.values()):
        column.unit = old_unit
    skymap.add_column(Column(uniq, name='UNIQ'), 0)
    skymap.sort('UNIQ')
    return skymap


def test():
    """Run the BAYESTAR C unit tests.

    Returns
    -------
    int
        The status reported by ``core.test()``, converted to `int`.

    Examples
    --------
    >>> test()
    0

    """
    status = core.test()
    return int(status)


# Remove this helper (defined earlier in the module, outside this excerpt)
# from the module namespace; presumably it was only needed while the module
# body was being executed.
del require_contiguous_aligned

ligo/skymap/bayestar/ez_emcee.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import numpy as np
from tqdm import tqdm

from .ptemcee import Sampler

# Public API of this module: the only name exported by ``import *``.
__all__ = ('ez_emcee',)


def logp(x, lo, hi):
    """Flat log prior on the box ``[lo, hi]``.

    Returns 0.0 where every coordinate along the last axis of ``x`` lies
    within its (inclusive) bounds, and ``-np.inf`` everywhere else.
    """
    within = np.logical_and(x >= lo, x <= hi)
    return np.where(np.all(within, axis=-1), 0.0, -np.inf)


def ez_emcee(log_prob_fn, lo, hi, nindep=200,
             ntemps=10, nwalkers=None, nburnin=500,
             args=(), kwargs=None, **options):
    r'''Fire-and-forget MCMC sampling using `ptemcee.Sampler`, featuring
    automated convergence monitoring, progress tracking, and thinning.

    The parameters are bounded in the box described by ``lo`` and ``hi``;
    points outside the box are rejected by a flat log prior. The initial
    walker positions are drawn uniformly from this box, so both ``lo`` and
    ``hi`` must be finite.

    If run in an interactive terminal, live progress is shown including the
    current sample number, the total required number of samples, time elapsed
    and estimated time remaining, acceptance fraction, and autocorrelation
    length.

    Sampling terminates when all chains have accumulated the requested number
    of independent samples.

    Parameters
    ----------
    log_prob_fn : callable
        The log probability function. It should take as its argument the
        parameter vector as an array of length ``ndim``, or if it is
        vectorized, a 2D array with ``ndim`` columns.
    lo : list, `numpy.ndarray`
        List of lower limits of parameters, of length ``ndim``.
    hi : list, `numpy.ndarray`
        List of upper limits of parameters, of length ``ndim``.
    nindep : int, optional
        Minimum number of independent samples.
    ntemps : int, optional
        Number of temperatures.
    nwalkers : int, optional
        Number of walkers. The default is 4 times the number of dimensions.
    nburnin : int, optional
        Number of samples to discard during burn-in phase.
    args : tuple, optional
        Extra positional arguments for ``log_prob_fn`` (forwarded to
        `ptemcee.Sampler` as ``loglargs``).
    kwargs : dict, optional
        Extra keyword arguments for ``log_prob_fn`` (forwarded to
        `ptemcee.Sampler` as ``loglkwargs``). Default: none.

    Returns
    -------
    chain : `numpy.ndarray`
        The thinned and flattened posterior sample chain,
        with at least ``nindep`` * ``nwalkers`` rows
        and exactly ``ndim`` columns.

    Other parameters
    ----------------
    **options :
        Extra keyword arguments for `ptemcee.Sampler`.
        *Tip:* Consider setting the `pool` or `vectorize` keyword arguments in
        order to speed up likelihood evaluations.

    Notes
    -----
    The autocorrelation length, which has a complexity of :math:`O(N \log N)`
    in the number of samples, is recalculated at geometrically progressing
    intervals so that its amortized complexity per sample is constant. (In
    simpler terms, as the chains grow longer and the autocorrelation length
    takes longer to compute, we update it less frequently so that it is never
    more expensive than sampling the chain in the first place.)

    Examples
    --------
    >>> from ligo.skymap.bayestar.ez_emcee import ez_emcee
    >>> from matplotlib import pyplot as plt
    >>> import numpy as np
    >>>
    >>> def log_prob(params):
    ...     """Eggbox function"""
    ...     return 5 * np.log((2 + np.cos(0.5 * params).prod(-1)))
    ...
    >>> lo = [-3*np.pi, -3*np.pi]
    >>> hi = [+3*np.pi, +3*np.pi]
    >>> chain = ez_emcee(log_prob, lo, hi, vectorize=True)   # doctest: +SKIP
    Sampling:  51%|██  | 8628/16820 [00:04<00:04, 1966.74it/s, accept=0.535, acl=62]
    >>> plt.plot(chain[:, 0], chain[:, 1], '.')   # doctest: +SKIP

    .. image:: eggbox.png

    '''  # noqa: E501
    lo = np.asarray(lo)
    hi = np.asarray(hi)
    ndim = len(lo)

    # Avoid a shared mutable default argument.
    if kwargs is None:
        kwargs = {}

    if nwalkers is None:
        nwalkers = 4 * ndim

    nsteps = 64

    with tqdm(total=nburnin + nindep * nsteps) as progress:

        sampler = Sampler(nwalkers, ndim, log_prob_fn, logp,
                          ntemps=ntemps, loglargs=args, loglkwargs=kwargs,
                          logpargs=[lo, hi], random=np.random, **options)
        pos = np.random.uniform(lo, hi, (ntemps, nwalkers, ndim))

        # Burn in
        progress.set_description('Burning in')
        for pos, _, _ in sampler.sample(
                pos, iterations=nburnin, storechain=False):
            progress.update()

        sampler.reset()
        acl = np.nan
        while not np.isfinite(acl) or sampler.time < nindep * acl:

            # Advance the chain
            progress.total = nburnin + max(sampler.time + nsteps,
                                           nindep * acl)
            progress.set_description('Sampling')
            for pos, _, _ in sampler.sample(pos, iterations=nsteps):
                progress.update()

            # Refresh convergence statistics
            progress.set_description('Checking')
            acl = sampler.get_autocorr_time()[0].max()
            if np.isfinite(acl):
                acl = max(1, int(np.ceil(acl)))
            accept = np.mean(sampler.acceptance_fraction[0])
            progress.set_postfix(acl=acl, accept=accept)

            # The autocorrelation time calculation has complexity N log N in
            # the number of posterior samples. Only refresh the autocorrelation
            # length estimate on logarithmically spaced samples so that the
            # amortized complexity per sample is constant.
            nsteps *= 2

    # Keep the zero-temperature chain, thinned by the autocorrelation length,
    # and flatten walkers x samples into rows.
    chain = sampler.chain[0, :, ::acl, :]
    s = chain.shape
    return chain.reshape((-1, s[-1]))

ligo/skymap/bayestar/filter.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
548  
549  
550  
551  
552  
553  
554  
555  
556  
557  
558  
559  
560  
561  
562  
563  
564  
565  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Utility functions for BAYESTAR that are related to matched filtering."""
import logging
import math

import lal
import lalsimulation
import numpy as np
from scipy import interpolate
from scipy import fftpack as fft
from scipy import linalg

# Module logger. It uses the fixed name 'BAYESTAR' rather than __name__, so
# any other code that requests that name gets this same logger object.
log = logging.getLogger('BAYESTAR')


def unwrap(y, *args, **kwargs):
    """Unwrap phases while skipping NaN or infinite values.

    This is a simple wrapper around :meth:`numpy.unwrap` that can handle
    invalid values: non-finite entries are copied to the output unchanged and
    are skipped over when unwrapping their finite neighbors.

    For input with more than one dimension, each 1-D slice along the last
    axis is unwrapped independently, matching the default behavior of
    :func:`numpy.unwrap`; phase corrections never carry over from the end of
    one slice into the start of the next. Extra positional and keyword
    arguments are forwarded to :func:`numpy.unwrap` for each 1-D slice.

    Examples
    --------
    >>> t = np.arange(0, 2 * np.pi, 0.5)
    >>> y = np.exp(1j * t)
    >>> unwrap(np.angle(y))
    array([0. , 0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. ])
    >>> y[3] = y[4] = y[7] = np.nan
    >>> unwrap(np.angle(y))
    array([0. , 0.5, 1. , nan, nan, 2.5, 3. , nan, 4. , 4.5, 5. , 5.5, 6. ])
    >>> y = np.asarray([[0., 3., -3.], [0., 3., -3.]])
    >>> np.allclose(unwrap(y), np.unwrap(y))
    True
    """
    y = np.asarray(y)
    result = np.empty_like(y)
    if y.ndim > 1:
        # Selecting y[good] on an N-D array would flatten it and unwrap
        # across row boundaries, so handle each last-axis lane separately.
        for index in np.ndindex(*y.shape[:-1]):
            result[index] = unwrap(y[index], *args, **kwargs)
        return result
    good = np.isfinite(y)
    result[~good] = y[~good]
    result[good] = np.unwrap(y[good], *args, **kwargs)
    return result


def ceil_pow_2(n):
    """Return the least integer power of 2 that is greater than or equal to n.

    The result is always a float. Negative or NaN inputs give NaN, positive
    infinity is returned unchanged, and zero maps to 1.0.

    Examples
    --------
    >>> ceil_pow_2(128.0)
    128.0
    >>> ceil_pow_2(0.125)
    0.125
    >>> ceil_pow_2(129.0)
    256.0
    >>> ceil_pow_2(0.126)
    0.25
    >>> ceil_pow_2(1.0)
    1.0
    >>> ceil_pow_2(float('inf'))
    inf

    """
    # frexp(inf) returns (inf, 0), which the formula below would turn into
    # 1.0; the only sensible answer for +inf is inf itself.
    if n == math.inf:
        return math.inf
    # frexp splits floats into mantissa and exponent, ldexp does the opposite.
    # For positive numbers, mantissa is in [0.5, 1.).
    mantissa, exponent = math.frexp(n)
    return math.ldexp(
        1 if mantissa >= 0 else float('nan'),
        exponent - 1 if mantissa == 0.5 else exponent
    )


def abscissa(series):
    """Return the sample coordinates of a lal TimeSeries or FrequencySeries.

    Time series (objects with ``deltaT`` and ``epoch``) give sample times;
    anything else is treated as a frequency series (``deltaF`` and ``f0``)
    and gives sample frequencies.
    """
    try:
        step, start = series.deltaT, float(series.epoch)
    except AttributeError:
        step, start = series.deltaF, series.f0
    count = len(series.data.data)
    return start + step * np.arange(count)


def exp_i(phi):
    """Return exp(i * phi), built from its cosine and sine components."""
    real_part = np.cos(phi)
    imag_part = np.sin(phi)
    return real_part + imag_part * 1j


def truncated_ifft(y, nsamples_out=None):
    r"""Truncated inverse FFT.

    See http://www.fftw.org/pruned.html for a discussion of related algorithms.

    Perform inverse FFT to obtain truncated autocorrelation time series.
    This makes use of a folded DFT for a speedup of::

        log(nsamples)/log(nsamples_out)

    over directly computing the inverse FFT and truncating. Here is how it
    works. Say we have a frequency-domain signal Y[k], for 0 ≤ k ≤ N - 1. We
    want to compute the first M samples of its inverse DFT x[n], for
    0 ≤ n ≤ M - 1, where N is divisible by M: N = cM, for some integer c. Up
    to the overall normalization factor 1/N, the inverse DFT is::

               N - 1
               ______
               \           2 π i k n
        x[n] =  \     exp[-----------] Y[k]
               /               N
              /------
               k = 0

               c - 1   M - 1
               ______  ______
               \       \           2 π i n (m c + j)
             =  \       \     exp[------------------] Y[m c + j]
               /       /                 c M
              /------ /------
               j = 0   m = 0

               c - 1                     M - 1
               ______                    ______
               \           2 π i n j     \           2 π i n m
             =  \     exp[-----------]    \     exp[-----------] Y[m c + j]
               /               N         /               M
              /------                   /------
               j = 0                     m = 0

    So: we split the frequency series into c deinterlaced sub-signals, each of
    length M, compute the inverse DFT of each sub-signal, and add them back
    together with complex weights.

    Parameters
    ----------
    y : `numpy.ndarray`
        Complex input vector.
    nsamples_out : int, optional
        Length of output vector. By default, same as length of input vector.

    Returns
    -------
    x : `numpy.ndarray`
        The first ``nsamples_out`` samples of the inverse FFT of ``y``.

    Raises
    ------
    ValueError
        If ``nsamples_out`` is greater than the length of ``y``.

    Notes
    -----
    The folded algorithm requires ``len(y)`` to be a power of 2. For other
    lengths, this function computes the full inverse FFT and truncates it.

    Examples
    --------
    First generate the IFFT of a random signal:

    >>> nsamples_out = 1024
    >>> y = np.random.randn(nsamples_out) + np.random.randn(nsamples_out) * 1j
    >>> x = fft.ifft(y)

    Now check that the truncated IFFT agrees:

    >>> np.allclose(x, truncated_ifft(y), rtol=1e-15)
    True
    >>> np.allclose(x, truncated_ifft(y, 1024), rtol=1e-15)
    True
    >>> np.allclose(x[:128], truncated_ifft(y, 128), rtol=1e-15)
    True
    >>> np.allclose(x[:1], truncated_ifft(y, 1), rtol=1e-15)
    True
    >>> np.allclose(x[:32], truncated_ifft(y, 32), rtol=1e-15)
    True
    >>> np.allclose(x[:63], truncated_ifft(y, 63), rtol=1e-15)
    True
    >>> np.allclose(x[:25], truncated_ifft(y, 25), rtol=1e-15)
    True
    >>> truncated_ifft(y, 1025)
    Traceback (most recent call last):
      ...
    ValueError: Input is too short: you gave me an input of length 1024, but you asked for an IFFT of length 1025.

    Input lengths that are not powers of 2 use the direct method:

    >>> y = np.random.randn(12) + np.random.randn(12) * 1j
    >>> np.allclose(fft.ifft(y)[:5], truncated_ifft(y, 5))
    True

    """  # noqa: E501
    nsamples = len(y)
    if nsamples_out is None:
        nsamples_out = nsamples
    elif nsamples_out > nsamples:
        raise ValueError(
            'Input is too short: you gave me an input of length {0}, '
            'but you asked for an IFFT of length {1}.'.format(
                nsamples, nsamples_out))

    # The folding below needs the (power of 2) batch length to divide
    # nsamples evenly, which only holds if nsamples is itself a power of 2.
    # This check used to be skipped when nsamples_out was omitted, leading to
    # a reshape error; now other lengths fall back to the direct method.
    # FIXME: Would be better to find the smallest divisor of nsamples that is
    # greater than or equal to nsamples_out.
    if nsamples & (nsamples - 1):
        return fft.ifft(y)[:nsamples_out]

    # Find number of FFTs.
    nsamples_batch = int(ceil_pow_2(nsamples_out))
    c = nsamples // nsamples_batch

    # FIXME: Implement for real-to-complex FFTs as well.
    twiddle = exp_i(2 * np.pi * np.arange(nsamples_batch) / nsamples)

    # Row j of x is the inverse DFT of the deinterlaced sub-signal y[j::c].
    x = fft.ifft(y.reshape(nsamples_batch, c).T)

    # Combine the rows with weights twiddle**j using Horner's scheme.
    result = x[-1]
    for row in x[-2::-1]:
        result *= twiddle  # FIXME: check stability of this recurrence relation
        result += row

    # Now need to truncate remaining samples.
    if nsamples_out < nsamples_batch:
        result = result[:nsamples_out]

    # Each sub-IFFT is normalized by 1/M; divide by c for the overall 1/N.
    return result / c


def get_approximant_and_orders_from_string(s):
    """Determine the approximant, amplitude order, and phase order for a string
    of the form "TaylorT4threePointFivePN".

    In this example, the waveform is "TaylorT4" and the phase order is 7
    (twice 3.5). If lalsimulation cannot extract a PN order from the string
    (it raises RuntimeError), the phase order is -1. If the input contains
    the substring "restricted" or "Restricted", then the amplitude order is
    taken to be 0; otherwise it is the same as the phase order.

    Returns
    -------
    approximant, amplitude_order, phase_order
    """
    # SWIG-wrapped functions apparently do not understand Unicode, but
    # often the input argument will come from a Unicode XML file.
    name = str(s)
    approximant = lalsimulation.GetApproximantFromString(name)
    try:
        phase_order = lalsimulation.GetOrderFromString(name)
    except RuntimeError:
        phase_order = -1
    restricted = any(word in name for word in ('restricted', 'Restricted'))
    amplitude_order = 0 if restricted else phase_order
    return approximant, amplitude_order, phase_order


def get_f_lso(mass1, mass2):
    """Calculate the GW frequency during the last stable orbit of a compact
    binary.

    Evaluates ``1 / (6**1.5 * pi * M)``, where ``M`` is the total mass
    ``mass1 + mass2`` converted to seconds with ``lal.MTSUN_SI`` (so the
    masses are presumably in solar masses and the result in Hz).
    """
    denominator = 6 ** 1.5 * np.pi * (mass1 + mass2) * lal.MTSUN_SI
    return 1 / denominator


def sngl_inspiral_psd(waveform, mass1, mass2,
                      f_min=10, f_final=None, f_ref=None, **kwargs):
    """Compute the signal power spectrum of a frequency-domain template.

    Parameters
    ----------
    waveform : str
        Approximant string understood by
        `get_approximant_and_orders_from_string`, or one of the special
        values ``'o1-uberbank'`` / ``'o2-uberbank'``, which choose an
        approximant based on the total mass.
    mass1, mass2 : float
        Component masses, in solar masses.
    f_min : float, optional
        Starting frequency passed to `lalsimulation.SimInspiralFD`.
    f_final : float, optional
        High-frequency cutoff (default: 2048). Bins at or above it are zeroed.
    f_ref : float, optional
        Reference frequency (default: 0).
    **kwargs
        Optional spin components ``spin1x`` ... ``spin2z``; missing or falsy
        values are treated as 0.

    Returns
    -------
    psd : `lal.REAL8FrequencySeries`
        ``|h|**2`` with ``h = (h_plus + i h_cross) / 2``, truncated to an even
        number of samples.
    """
    # FIXME: uberbank mass criterion. Should find a way to get this from
    # pipeline output metadata.
    uberbanks = {
        'o1-uberbank': ('Template is unspecified; '
                        'using ER8/O1 uberbank criterion',
                        'SEOBNRv2_ROM_DoubleSpin'),
        'o2-uberbank': ('Template is unspecified; '
                        'using ER10/O2 uberbank criterion',
                        'SEOBNRv4_ROM'),
    }
    if waveform in uberbanks:
        message, high_mass_waveform = uberbanks[waveform]
        log.warning(message)
        waveform = ('TaylorF2threePointFivePN' if mass1 + mass2 < 4
                    else high_mass_waveform)
    approx, ampo, phaseo = get_approximant_and_orders_from_string(waveform)
    log.info('Selected template: %s', waveform)

    # Generate conditioned template.
    params = lal.CreateDict()
    lalsimulation.SimInspiralWaveformParamsInsertPNPhaseOrder(params, phaseo)
    lalsimulation.SimInspiralWaveformParamsInsertPNAmplitudeOrder(params, ampo)
    m1 = float(mass1) * lal.MSUN_SI
    m2 = float(mass2) * lal.MSUN_SI
    spin_names = (('S1x', 'spin1x'), ('S1y', 'spin1y'), ('S1z', 'spin1z'),
                  ('S2x', 'spin2x'), ('S2y', 'spin2y'), ('S2z', 'spin2z'))
    spins = {arg: float(kwargs.get(key) or 0) for arg, key in spin_names}
    hplus, hcross = lalsimulation.SimInspiralFD(
        m1=m1, m2=m2,
        distance=1e6 * lal.PC_SI, inclination=0, phiRef=0,
        longAscNodes=0, eccentricity=0, meanPerAno=0,
        deltaF=0, f_min=f_min,
        # Note: code elsewhere ensures that the sample rate is at least two
        # times f_final; the factor of 2 below is just a safety factor to make
        # sure that the sample rate is 2-4 times f_final.
        f_max=ceil_pow_2(2 * (f_final or 2048)),
        f_ref=float(f_ref or 0),
        LALparams=params, approximant=approx, **spins)

    # Force `plus' and `cross' waveform to be in quadrature.
    h = 0.5 * (hplus.data.data + 1j * hcross.data.data)
    freqs = abscissa(hplus)

    # For inspiral-only waveforms, nullify frequencies beyond ISCO.
    # FIXME: the waveform generation functions pick the end frequency
    # automatically. Shouldn't SimInspiralFD?
    inspiral_only_waveforms = (
        lalsimulation.TaylorF2,
        lalsimulation.SpinTaylorF2,
        lalsimulation.TaylorF2RedSpin,
        lalsimulation.TaylorF2RedSpinTidal,
        lalsimulation.SpinTaylorT4Fourier)
    if approx in inspiral_only_waveforms:
        h[freqs >= get_f_lso(mass1, mass2)] = 0

    # Throw away any frequencies above high frequency cutoff
    h[freqs >= (f_final or 2048)] = 0

    # Drop Nyquist frequency.
    if len(h) % 2:
        h = h[:-1]

    # Create output frequency series.
    psd = lal.CreateREAL8FrequencySeries(
        'signal PSD', 0, hplus.f0, hcross.deltaF, hplus.sampleUnits**2, len(h))
    psd.data.data = abs2(h)
    return psd


def signal_psd_series(H, S):
    """Divide a signal PSD series by a noise PSD function, bin by bin.

    Parameters
    ----------
    H : lal.REAL8FrequencySeries
        Signal PSD series.
    S : callable
        Noise PSD function, evaluated at the frequencies of every sample of
        ``H`` except the first.

    Returns
    -------
    lal.REAL8FrequencySeries
        Dimensionless series with the same ``f0``, ``deltaF``, and length as
        ``H``; the first sample is 0 and the rest are ``H / S(f)``.
    """
    count = H.data.data.size
    freqs = H.f0 + np.arange(1, count) * H.deltaF
    ratio = lal.CreateREAL8FrequencySeries(
        'signal PSD / noise PSD', 0, H.f0, H.deltaF,
        lal.DimensionlessUnit, count)
    ratio.data.data[0] = 0
    ratio.data.data[1:] = H.data.data[1:] / S(freqs)
    return ratio


def autocorrelation(H, out_duration, normalize=True):
    """Calculate the complex autocorrelation sequence a(t), for t >= 0, of an
    inspiral signal.

    The sequence is the inverse FFT of ``H`` zero-padded to twice its length,
    truncated to the first ``out_duration`` worth of samples.

    Parameters
    ----------
    H : lal.REAL8FrequencySeries
        Signal PSD series.
    out_duration : float
        Duration of the output sequence, in units of ``1 / H.deltaF``
        (seconds if ``H.deltaF`` is in Hz). The number of output samples is
        ``ceil(out_duration * sample_rate)``.
    normalize : bool, optional
        If True (the default), divide the sequence by the magnitude of its
        zero-lag sample so that ``abs(acor[0]) == 1``.

    Returns
    -------
    acor : `numpy.ndarray`
        The complex-valued autocorrelation sequence.
    sample_rate : float
        The sample rate, ``2 * len(H.data.data) * H.deltaF``.

    """
    # Length of the zero-padded FFT and the corresponding sample rate.
    H_len = H.data.data.size
    nsamples = 2 * H_len
    sample_rate = nsamples * H.deltaF

    # Compute autopower spectral density: H followed by zero padding.
    power = np.empty(nsamples, H.data.data.dtype)
    power[:H_len] = H.data.data
    power[H_len:] = 0

    # Determine length of output FFT.
    nsamples_out = int(np.ceil(out_duration * sample_rate))

    acor = truncated_ifft(power, nsamples_out)
    if normalize:
        acor /= np.abs(acor[0])

    # Sanity check: if we have done this right, then the zeroth sample
    # represents lag 0.
    if np.all(np.isreal(H.data.data)):
        assert np.argmax(np.abs(acor)) == 0
        assert np.isreal(acor[0])

    # Done!
    return acor, float(sample_rate)


def abs2(y):
    """Return the squared magnitude :math:`|z|^2` of a complex number or
    array :math:`z`, without taking a square root.
    """
    re_part, im_part = y.real, y.imag
    return np.square(re_part) + np.square(im_part)


class vectorize_swig_psd_func:  # noqa: N801
    """Create a vectorized Numpy function from a SWIG-wrapped PSD function.
    SWIG does not provide enough information for Numpy to determine the number
    of input arguments, so we can't just use np.vectorize.

    Parameters
    ----------
    str : str
        Name of a PSD function in `lalsimulation`. Both that function and its
        ``Ptr`` counterpart (the name with ``'Ptr'`` appended) are looked up.

    Notes
    -----
    A 1-D array of three or more evenly spaced frequencies is evaluated with a
    single call to `lalsimulation.SimNoisePSD`. Any other input (a scalar, a
    short or irregular 1-D array, or a multidimensional array) is evaluated
    element by element.
    """

    # The parameter name ``str`` shadows the builtin; it is kept unchanged for
    # compatibility with callers that pass it by keyword.
    def __init__(self, str):
        self.__func = getattr(lalsimulation, str + 'Ptr')
        self.__npyfunc = np.frompyfunc(getattr(lalsimulation, str), 1, 1)

    def __call__(self, f):
        fa = np.asarray(f)
        # np.diff() requires at least one dimension, so only look for an
        # evenly spaced grid when the input is 1-D. Scalars (0-d input) go
        # straight to the element-wise path instead of raising.
        regular = False
        if fa.ndim == 1:
            df = np.diff(fa)
            regular = df.size > 1 and bool(np.all(df[0] == df[1:]))
        if regular:
            # Extend the grid by one sample, fill it with one SimNoisePSD
            # call, then drop the extra sample again.
            fa = np.concatenate((fa, [fa[-1] + df[0]]))
            ret = lal.CreateREAL8FrequencySeries(
                None, 0, fa[0], df[0], lal.DimensionlessUnit, fa.size)
            lalsimulation.SimNoisePSD(ret, 0, self.__func)
            ret = ret.data.data[:-1]
        else:
            ret = self.__npyfunc(f)
        # np.frompyfunc returns object arrays; convert them to floats.
        if not np.isscalar(ret):
            ret = ret.astype(float)
        return ret


class InterpolatedPSD(interpolate.interp1d):
    """Create a (linear in log-log) interpolating function for a discretely
    sampled power spectrum S(f).

    Parameters
    ----------
    f : array_like
        Frequencies at which the PSD is sampled. A leading DC sample
        (``f[0] == 0``) is discarded because its logarithm is undefined.
    S : array_like
        PSD values at the frequencies ``f``.
    f_high_truncate : float, optional
        If less than 1, drop samples above ``f_high_truncate * max(f)``.
        Must not exceed 1.
    fill_value : float, optional
        Value returned outside the sampled frequency range (default: inf).
    """

    def __init__(self, f, S, f_high_truncate=1.0, fill_value=np.inf):
        assert f_high_truncate <= 1.0
        f = np.asarray(f)
        S = np.asarray(S)

        # Exclude DC if present
        if f[0] == 0:
            f = f[1:]
            S = S[1:]
        # FIXME: This is a hack to fix an issue with the detection pipeline's
        # PSD conditioning. Remove this when the issue is fixed upstream.
        if f_high_truncate < 1.0:
            log.warning(
                'Truncating PSD at %g of maximum frequency to suppress '
                'rolloff artifacts. This option may be removed in the future.',
                f_high_truncate)
            keep = (f <= f_high_truncate * max(f))
            f = f[keep]
            S = S[keep]
        # Interpolate linearly in (log f, log S); the fill value is stored in
        # log space as well and exponentiated again in __call__.
        super().__init__(
            np.log(f), np.log(S),
            kind='linear', bounds_error=False, fill_value=np.log(fill_value))
        self._f_min = min(f)
        self._f_max = max(f)

    @property
    def f_min(self):
        """Lowest sampled frequency retained by the interpolant."""
        return self._f_min

    @property
    def f_max(self):
        """Highest sampled frequency retained by the interpolant."""
        return self._f_max

    def __call__(self, f):
        """Evaluate the PSD at frequency or frequencies ``f``.

        ``f`` may be a scalar, a sequence, or an array. Frequencies outside
        ``[f_min, f_max]`` evaluate to the fill value, and a warning is
        logged for each side that is exceeded.
        """
        # Plain Python sequences do not support the elementwise comparisons
        # below, so normalize the input to an array first.
        f = np.asarray(f)
        f_min = np.min(f)
        f_max = np.max(f)
        if f_min < self._f_min:
            log.warning('Assuming PSD is infinite at %g Hz because PSD is '
                        'only sampled down to %g Hz', f_min, self._f_min)
        if f_max > self._f_max:
            log.warning('Assuming PSD is infinite at %g Hz because PSD is '
                        'only sampled up to %g Hz', f_max, self._f_max)
        in_range = (f >= self._f_min) & (f <= self._f_max)
        # log(0) emits a divide-by-zero warning, but such points are out of
        # range and replaced by the fill value below anyway.
        with np.errstate(divide='ignore'):
            log_psd = super().__call__(np.log(f))
        return np.where(in_range, np.exp(log_psd), np.exp(self.fill_value))


class SignalModel:
    """Class to speed up computation of signal/noise-weighted integrals and
    Barankin and Cramér-Rao lower bounds on time and phase estimation.

    Note that the autocorrelation series and the moments are related,
    as shown below.

    Examples
    --------
    Create signal model:

    >>> from . import filter
    >>> sngl = lambda: None
    >>> H = filter.sngl_inspiral_psd(
    ...     'TaylorF2threePointFivePN', mass1=1.4, mass2=1.4)
    >>> S = vectorize_swig_psd_func('SimNoisePSDaLIGOZeroDetHighPower')
    >>> W = filter.signal_psd_series(H, S)
    >>> sm = SignalModel(W)

    Compute one-sided autocorrelation function:

    >>> out_duration = 0.1
    >>> a, sample_rate = filter.autocorrelation(W, out_duration)

    Restore negative time lags using symmetry:

    >>> a = np.concatenate((a[:0:-1].conj(), a))

    Compute the first 2 frequency moments by taking derivatives of the
    autocorrelation sequence using centered finite differences.
    The nth frequency moment should be given by (-1j)^n a^(n)(t).

    >>> acor_moments = []
    >>> for i in range(2):
    ...     acor_moments.append(a[len(a) // 2])
    ...     a = -0.5j * sample_rate * (a[2:] - a[:-2])
    >>> assert np.all(np.isreal(acor_moments))
    >>> acor_moments = np.real(acor_moments)

    Compute the first 2 frequency moments using this class.

    >>> quad_moments = [sm.get_sn_moment(i) for i in range(2)]

    Compare them.

    >>> for i, (am, qm) in enumerate(zip(acor_moments, quad_moments)):
    ...     assert np.allclose(am, qm, rtol=0.05)

    """

    def __init__(self, h):
        """Create a signal model from a signal/noise-weighted spectrum.

        Parameters
        ----------
        h : LAL frequency series
            Series exposing ``data.data`` (the samples), ``f0`` and
            ``deltaF``, e.g. the output of ``filter.signal_psd_series`` as in
            the class example. Leading and trailing zero samples are ignored.

        Raises
        ------
        IndexError
            If every sample of ``h`` is zero.
        """
        # Find indices of first and last nonzero samples.
        nonzero = np.flatnonzero(h.data.data)
        first_nonzero = nonzero[0]
        last_nonzero = nonzero[-1]

        # Angular-frequency sample points and their spacing.
        self.dw = 2 * np.pi * h.deltaF
        f = h.f0 + h.deltaF * np.arange(first_nonzero, last_nonzero + 1)
        self.w = 2 * np.pi * f

        # Throw away leading and trailing zeros.
        h = h.data.data[first_nonzero:last_nonzero + 1]

        # The factor 1 / (2 pi) converts an integral over f into one over
        # w = 2 pi f, which is what dx=self.dw integrates over.
        self.denom_integrand = 4 / (2 * np.pi) * h
        self.den = self._trapezoid(self.denom_integrand, dx=self.dw)

    @staticmethod
    def _trapezoid(y, dx):
        """Trapezoidal-rule integral of ``y`` with uniform spacing ``dx``.

        NumPy 2.0 renamed ``np.trapz`` to ``np.trapezoid`` and deprecated the
        old name; use the new one when available.
        """
        integrate = getattr(np, 'trapezoid', None)
        if integrate is None:
            integrate = np.trapz
        return integrate(y, dx=dx)

    def get_horizon_distance(self, snr_thresh=1):
        """Return ``sqrt(den) / snr_thresh``, the distance scale at which the
        modeled signal reaches signal-to-noise ratio ``snr_thresh`` (in the
        distance units implied by the normalization of ``h``).
        """
        return np.sqrt(self.den) / snr_thresh

    def get_sn_average(self, func):
        """Get the average of a function of angular frequency, weighted by the
        signal to noise per unit angular frequency.
        """
        num = self._trapezoid(func(self.w) * self.denom_integrand, dx=self.dw)
        return num / self.den

    def get_sn_moment(self, power):
        """Get the average of angular frequency to the given power, weighted by
        the signal to noise per unit angular frequency.
        """
        return self.get_sn_average(lambda w: w**power)

    def get_crb(self, snr):
        """Get the Cramér-Rao bound, or inverse Fisher information matrix,
        describing the phase and time estimation covariance.

        Index 0 is phase and index 1 is time; the matrix scales as
        ``1 / snr**2``.
        """
        w1 = self.get_sn_moment(1)
        w2 = self.get_sn_moment(2)
        fisher = np.asarray(((1, -w1), (-w1, w2)))
        return linalg.inv(fisher) / np.square(snr)

    # FIXME: np.vectorize doesn't work on unbound instance methods. The
    # excluded keyword, added in Numpy 1.7, could be used here to exclude the
    # zeroth argument, self.
    def __get_crb_toa_uncert(self, snr):
        return np.sqrt(self.get_crb(snr)[1, 1])

    def get_crb_toa_uncert(self, snr):
        """Return the Cramér-Rao bound on the arrival-time standard deviation
        for each SNR in ``snr``. Array input yields an object array, as
        produced by :func:`numpy.frompyfunc`.
        """
        return np.frompyfunc(self.__get_crb_toa_uncert, 1, 1)(snr)

ligo/skymap/bayestar/interpolation.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Sub-sample interpolation for matched filter time series.

Example
-------
.. plot::
   :context: reset
   :include-source:
   :align: center

    from ligo.skymap.bayestar.interpolation import interpolate_max
    from matplotlib import pyplot as plt
    import numpy as np

    z = np.asarray([ 9.135017 -2.8185585j,  9.995214 -1.1222992j,
                    10.682851 +0.8188147j, 10.645139 +3.0268786j,
                     9.713133 +5.5589147j,  7.9043484+7.9039335j,
                     5.511646 +9.333084j ,  2.905198 +9.715742j ,
                     0.5302934+9.544538j ])

    amp = np.abs(z)
    arg = np.rad2deg(np.unwrap(np.angle(z)))
    arg -= (np.median(arg) // 360) * 360
    imax = np.argmax(amp)
    window = 4

    fig, (ax_amp, ax_arg) = plt.subplots(2, 1, figsize=(5, 6), sharex=True)
    ax_arg.set_xlabel('Sample index')
    ax_amp.set_ylabel('Amplitude')
    ax_arg.set_ylabel('Phase')
    args, kwargs = ('.-',), dict(color='lightgray', label='data')
    ax_amp.plot(amp, *args, **kwargs)
    ax_arg.plot(arg, *args, **kwargs)
    for method in ['lanczos', 'catmull-rom',
                   'quadratic-fit', 'nearest-neighbor']:
        i, y = interpolate_max(imax, z, window, method)
        amp = np.abs(y)
        arg = np.rad2deg(np.angle(y))
        args, kwargs = ('o',), dict(mfc='none', label=method)
        ax_amp.plot(i, amp, *args, **kwargs)
        ax_arg.plot(i, arg, *args, **kwargs)
    ax_arg.legend()
    fig.tight_layout()

"""
import numpy as np
from scipy import optimize

from .filter import abs2, exp_i, unwrap

__all__ = ('interpolate_max',)


#
# Lanczos interpolation
#


def lanczos(t, a):
    """Evaluate the Lanczos kernel of order ``a`` at ``t``.

    Inside the window ``|t| < a`` the kernel is ``sinc(t) * sinc(t / a)``;
    outside it the kernel is zero.
    """
    inside_window = np.abs(t) < a
    kernel = np.sinc(t) * np.sinc(t / a)
    return np.where(inside_window, kernel, 0)


def lanczos_interpolant(t, y):
    """Evaluate at ``t`` the convolution of the Lanczos kernel with samples
    ``y`` taken at unit spacing.

    Sample ``y[i]`` sits at offset ``i - len(y) // 2``, so ``t = 0`` is the
    middle sample. The kernel order is ``len(y) // 2``.
    """
    half_width = len(y) // 2
    total = 0
    for index, sample in enumerate(y):
        total = total + lanczos(t - index + half_width, half_width) * sample
    return total


def lanczos_interpolant_utility_func(t, y):
    """Objective for the Lanczos peak search: the negated squared magnitude of
    the Lanczos interpolant of ``y`` at ``t``.

    Minimizing this function maximizes the interpolated ``|y|``.
    """
    value = lanczos_interpolant(t, y)
    return -abs2(value)


def interpolate_max_lanczos(imax, y, window_length):
    """Find the time and maximum absolute value of a time series by Lanczos
    interpolation.

    Parameters
    ----------
    imax : int
        Index of the maximum sample.
    y : numpy.ndarray
        The series; the ``2 * window_length + 1`` samples centered on
        ``imax`` are used.
    window_length : int
        Half-width of the window, which is also the Lanczos kernel order.

    Returns
    -------
    t : float
        ``imax`` plus the fractional offset of the peak, searched in [-1, 1].
    ymax : scalar
        The interpolated value at the peak, as a Python scalar.
    """
    yi = y[(imax - window_length):(imax + window_length + 1)]
    tmax = optimize.fminbound(
        lanczos_interpolant_utility_func, -1., 1., (yi,), xtol=1e-5)
    # The type returned by fminbound has varied between SciPy versions and
    # optimizer paths (Python float vs. NumPy scalar / 0-d array).
    # np.asarray(...).item() yields a Python scalar for all of them, whereas
    # calling .item() directly fails on a plain Python float.
    tmax = np.asarray(tmax).item()
    ymax = np.asarray(lanczos_interpolant(tmax, yi)).item()
    return imax + tmax, ymax


#
# Catmull-Rom spline interpolation, real and imaginary parts
#


def poly_catmull_rom(y):
    """Return the cubic Catmull-Rom segment as a :class:`numpy.poly1d` in t.

    The segment passes through ``y[1]`` at t=0 and ``y[2]`` at t=1. Its end
    tangents are set by the neighbouring samples ``y[0]`` and ``y[3]``. Only
    the first four entries of ``y`` are used.
    """
    p0, p1, p2, p3 = y[0], y[1], y[2], y[3]
    cubic = -0.5 * p0 + 1.5 * p1 - 1.5 * p2 + 0.5 * p3
    quadratic = p0 - 2.5 * p1 + 2 * p2 - 0.5 * p3
    linear = -0.5 * p0 + 0.5 * p2
    constant = p1
    return np.poly1d([cubic, quadratic, linear, constant])


def interpolate_max_catmull_rom_even(y):
    """Locate the peak of ``|y|`` between ``y[1]`` (t=0) and ``y[2]`` (t=1)
    using Catmull-Rom splines on the real and imaginary parts.

    Returns
    -------
    t_max : float
        Offset of the peak from ``y[1]``, in [0, 1].
    y_max : complex
        Interpolated (or sampled) value at ``t_max``.
    """
    poly_re = poly_catmull_rom(y.real)
    poly_im = poly_catmull_rom(y.imag)

    # Stationary points of |y(t)|^2 = re(t)^2 + im(t)^2 satisfy
    # re * re' + im * im' = 0.
    stationary = (poly_re * poly_re.deriv() + poly_im * poly_im.deriv()).r

    # Baseline: the larger of the two interior samples (a tie keeps y[1]).
    left_abs2 = abs2(y[1])
    right_abs2 = abs2(y[2])
    if right_abs2 > left_abs2:
        best_t, best_y, best_abs2 = 1., y[2], right_abs2
    else:
        best_t, best_y, best_abs2 = 0., y[1], left_abs2

    # Accept any real stationary point strictly inside (0, 1) that beats it.
    for root in stationary:
        if not (np.isreal(root) and 0 < root < 1):
            continue
        t = np.real(root)
        candidate = poly_re(t) + poly_im(t) * 1j
        candidate_abs2 = abs2(candidate)
        if candidate_abs2 > best_abs2:
            best_t, best_y, best_abs2 = t, candidate, candidate_abs2

    return best_t, best_y


def interpolate_max_catmull_rom(imax, y, window_length):
    """Sub-sample peak of ``|y|`` near ``imax`` from Catmull-Rom splines on
    the real and imaginary parts.

    Two spline segments are tried, one spanning ``[imax - 1, imax]`` and one
    spanning ``[imax, imax + 1]``. The segment with the larger interpolated
    magnitude wins; a tie goes to the left segment. ``window_length`` is
    ignored.
    """
    t_left, y_left = interpolate_max_catmull_rom_even(y[imax - 2:imax + 2])
    t_right, y_right = interpolate_max_catmull_rom_even(y[imax - 1:imax + 3])
    if abs2(y_right) > abs2(y_left):
        return imax + t_right, y_right
    # The left segment's t=0 sits at imax - 1, hence the shift by one.
    return imax + (t_left - 1), y_left


#
# Catmull-Rom spline interpolation, amplitude and phase
#


def interpolate_max_catmull_rom_amp_phase_even(y):
    """Locate the peak of ``|y|`` between ``y[1]`` (t=0) and ``y[2]`` (t=1)
    using Catmull-Rom splines on the amplitude and the unwrapped phase.

    Parameters
    ----------
    y : numpy.ndarray
        Four consecutive complex samples.

    Returns
    -------
    t_max : float
        Offset of the peak from ``y[1]``, in [0, 1].
    y_max : complex
        Interpolated (or sampled) value at ``t_max``.
    """
    # Construct Catmull-Rom interpolating polynomials for
    # amplitude and (unwrapped) phase
    poly_abs = poly_catmull_rom(np.abs(y))
    poly_arg = poly_catmull_rom(unwrap(np.angle(y)))

    # Candidate peaks are the stationary points of the amplitude spline,
    # i.e. the roots of d|y|/dt. The roots of |y| itself would be zeros of the
    # amplitude, not its maximum.
    roots = poly_abs.deriv().r

    # Find which of the two matched interior points has a greater magnitude
    t_max = 0.
    y_max = y[1]
    y_max_abs2 = abs2(y_max)

    new_t_max = 1.
    new_y_max = y[2]
    new_y_max_abs2 = abs2(new_y_max)

    if new_y_max_abs2 > y_max_abs2:
        t_max = new_t_max
        y_max = new_y_max
        y_max_abs2 = new_y_max_abs2

    # Find any real root in (0, 1) that has a magnitude greater than the
    # greatest endpoint
    for root in roots:
        if np.isreal(root) and 0 < root < 1:
            new_t_max = np.real(root)
            new_y_max = poly_abs(new_t_max) * exp_i(poly_arg(new_t_max))
            new_y_max_abs2 = abs2(new_y_max)
            if new_y_max_abs2 > y_max_abs2:
                t_max = new_t_max
                y_max = new_y_max
                y_max_abs2 = new_y_max_abs2

    # Done
    return t_max, y_max


def interpolate_max_catmull_rom_amp_phase(imax, y, window_length):
    """Sub-sample peak of ``|y|`` near ``imax`` from Catmull-Rom splines on
    the amplitude and phase.

    Two spline segments are tried, one spanning ``[imax - 1, imax]`` and one
    spanning ``[imax, imax + 1]``. The segment with the larger interpolated
    magnitude wins; a tie goes to the left segment. ``window_length`` is
    ignored.
    """
    even = interpolate_max_catmull_rom_amp_phase_even
    t_left, y_left = even(y[imax - 2:imax + 2])
    t_right, y_right = even(y[imax - 1:imax + 3])
    if abs2(y_right) > abs2(y_left):
        return imax + t_right, y_right
    # The left segment's t=0 sits at imax - 1, hence the shift by one.
    return imax + (t_left - 1), y_left


#
# Quadratic fit
#


def interpolate_max_quadratic_fit(imax, y, window_length):
    """Quadratic fit to absolute value of y.

    A parabola is least-squares fitted to ``|y|`` over the
    ``2 * window_length + 1`` samples centered on ``imax``. The vertex is
    used only if all of the following hold:

    * the parabola opens downward,
    * the vertex lies strictly within one sample of ``imax``, and
    * the vertex value exceeds ``|y[imax]|``.

    In that case the phase is linearly interpolated from the unwrapped sample
    phases. Otherwise the sample at ``imax`` is returned unchanged.
    """
    t = np.arange(-window_length, window_length + 1.)
    y = y[imax - window_length:imax + window_length + 1]
    y_abs = np.abs(y)
    a, b, c = np.polyfit(t, y_abs, 2)

    # Baseline: the sample at imax itself (t = 0, the center of the window).
    t_max = 0.
    y_max = y[window_length]
    y_max_abs = y_abs[window_length]

    # The vertex is a local maximum only if the parabola opens downward
    # (a < 0). This guard also avoids dividing by zero for a degenerate fit.
    if a < 0:
        new_t_max = -0.5 * b / a
        new_y_max_abs = c - 0.25 * np.square(b) / a
        if -1 < new_t_max < 1 and new_y_max_abs > y_max_abs:
            t_max = new_t_max
            y_max_abs = new_y_max_abs
            y_phase = np.interp(t_max, t, np.unwrap(np.angle(y)))
            y_max = y_max_abs * exp_i(y_phase)

    return imax + t_max, y_max


#
# Nearest neighbor interpolation
#


def interpolate_max_nearest_neighbor(imax, y, window_length):
    """Trivial "interpolation": return ``imax`` and the sample stored there,
    unchanged. ``window_length`` is ignored.
    """
    peak_value = y[imax]
    return imax, peak_value


#
# Set default interpolation scheme
#


# Registry mapping each method name accepted by interpolate_max() to the
# function that implements it.
_interpolants = dict([
    ('catmull-rom-amp-phase', interpolate_max_catmull_rom_amp_phase),
    ('catmull-rom', interpolate_max_catmull_rom),
    ('lanczos', interpolate_max_lanczos),
    ('nearest-neighbor', interpolate_max_nearest_neighbor),
    ('quadratic-fit', interpolate_max_quadratic_fit),
])


def interpolate_max(imax, y, window_length, method='catmull-rom-amp-phase'):
    """Perform sub-sample interpolation to find the phase and amplitude
    at the maximum of the absolute value of a complex series.

    Parameters
    ----------
    imax : int
        The index of the maximum sample in the series.
    y : `numpy.ndarray`
        The complex series.
    window_length : int
        The window of the interpolation function for the `lanczos` and
        `quadratic-fit` methods. The interpolation will consider a sliding
        window of `2 * window_length + 1` samples centered on `imax`.
    method : {'catmull-rom-amp-phase', 'catmull-rom', 'lanczos', 'nearest-neighbor', 'quadratic-fit'}
        The interpolation method. Can be any of the following:

        * `catmull-rom-amp-phase`: Catmull-Rom cubic splines on amplitude and phase
          The `window_length` parameter is ignored (understood to be 2).
        * `catmull-rom`: Catmull-Rom cubic splines on real and imaginary parts
          The `window_length` parameter is ignored (understood to be 2).
        * `lanczos`: Lanczos filter interpolation
        * `nearest-neighbor`: Nearest neighbor (i.e., no interpolation).
          The `window_length` parameter is ignored (understood to be 0).
        * `quadratic-fit`: Fit the absolute value of the SNR to a quadratic
          function and the phase to a linear function.

    Returns
    -------
    imax_interp : float
        The interpolated index of the maximum sample. Every method keeps it
        within one sample of `imax`; for a well-sampled peak it is typically
        within half a sample.
    ymax_interp : complex
        The interpolated value at the maximum.

    Raises
    ------
    KeyError
        If `method` is not one of the names listed above.

    """  # noqa: E501
    interpolant = _interpolants[method]
    samples = np.asarray(y)
    return interpolant(imax, samples, window_length)

ligo/skymap/bayestar/ptemcee.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
# FIXME: Remove this file if https://github.com/willvousden/ptemcee/pull/6
# is merged
import numpy as np
import ptemcee.sampler

__all__ = ('Sampler',)


class VectorLikePriorEvaluator(ptemcee.sampler.LikePriorEvaluator):
    """Likelihood/prior evaluator that calls ``logp`` and ``logl`` once on a
    whole batch of positions instead of once per position.
    """

    def __call__(self, x):
        """Evaluate the log-likelihood and log-prior for every position in x.

        ``x`` is flattened to 2-D, with coordinates along its last axis, so
        ``logp`` and ``logl`` receive an array of rows and must return one
        value per row. The likelihood is evaluated only where the prior is not
        ``-inf``; elsewhere it is set to 0.

        Returns
        -------
        ll, lp : numpy.ndarray
            Log-likelihood and log-prior, each shaped like ``x.shape[:-1]``.

        Raises
        ------
        ValueError
            If either function returns NaN.
        """
        batch_shape = x.shape[:-1]
        flat = x.reshape((-1, x.shape[-1]))

        lp = self.logp(flat, *self.logpargs, **self.logpkwargs)
        if np.isnan(lp).any():
            raise ValueError('Prior function returned NaN.')

        excluded = (lp == -np.inf)
        allowed = ~excluded
        ll = np.empty_like(lp)
        ll[excluded] = 0
        ll[allowed] = self.logl(
            flat[allowed], *self.loglargs, **self.loglkwargs)
        if np.isnan(ll).any():
            raise ValueError('Log likelihood function returned NaN.')

        return ll.reshape(batch_shape), lp.reshape(batch_shape)


class Sampler(ptemcee.sampler.Sampler):
    """Patched version of :class:`ptemcee.Sampler` that supports the
    `vectorize` option of :class:`emcee.EnsembleSampler`.

    When ``vectorize`` is true, ``logl`` and ``logp`` are evaluated once per
    batch of positions (see :class:`VectorLikePriorEvaluator`). Otherwise
    evaluation is delegated to the stock :class:`ptemcee.Sampler`.
    """

    def __init__(self, nwalkers, dim, logl, logp,  # noqa: N803
                 ntemps=None, Tmax=None, betas=None,  # noqa: N803
                 threads=1, pool=None, a=2.0,
                 loglargs=None, logpargs=None,
                 loglkwargs=None, logpkwargs=None,
                 adaptation_lag=10000, adaptation_time=100,
                 random=None, vectorize=False):
        # Use fresh containers per instance rather than shared mutable
        # defaults; explicitly passed values are used as-is.
        loglargs = [] if loglargs is None else loglargs
        logpargs = [] if logpargs is None else logpargs
        loglkwargs = {} if loglkwargs is None else loglkwargs
        logpkwargs = {} if logpkwargs is None else logpkwargs
        super().__init__(nwalkers, dim, logl, logp,
                         ntemps=ntemps, Tmax=Tmax, betas=betas,
                         threads=threads, pool=pool, a=a, loglargs=loglargs,
                         logpargs=logpargs, loglkwargs=loglkwargs,
                         logpkwargs=logpkwargs, adaptation_lag=adaptation_lag,
                         adaptation_time=adaptation_time, random=random)
        self._vectorize = vectorize
        if vectorize:
            self._likeprior = VectorLikePriorEvaluator(logl, logp,
                                                       loglargs, logpargs,
                                                       loglkwargs, logpkwargs)

    def _evaluate(self, ps):
        """Evaluate log-likelihood and log-prior for the positions ``ps``."""
        if self._vectorize:
            return self._likeprior(ps)
        else:
            # Delegate to the parent's implementation of this same method.
            # The previous code called ``super().evaluate``, a different name
            # from the ``_evaluate`` method being overridden here.
            return super()._evaluate(ps)

ligo/skymap/coordinates/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

__all__ = ()

# Import all symbols from all submodules of this module.
# For every submodule found next to this file (except the 'tests' package),
# the exec'd statement does three things:
#   1. binds the submodule itself as an attribute of this package,
#   2. appends the submodule's __all__ (if any) to this package's __all__,
#   3. star-imports the submodule's public names into this namespace.
# exec is presumably used because the module names are only known at runtime.
# NOTE: __all__ here is a tuple, so a submodule's __all__ must also be a tuple
# (tuple += list raises TypeError).
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    del module

# Clean up
del os, pkgutil

ligo/skymap/coordinates/detector.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Astropy coordinate frames to visualize triangulation rings from pairs of
gravitational-wave detectors. These are useful for generating plots similar to
Fig. 2 of the GW150914 localization and follow-up paper [1]_.

Example
-------
.. plot::
   :context: reset
   :include-source:
   :align: center

    from astropy.coordinates import EarthLocation
    from astropy.time import Time
    from ligo.skymap.coordinates import DetectorFrame
    from ligo.skymap.io import read_sky_map
    import ligo.skymap.plot
    from matplotlib import pyplot as plt

    # Download GW150914 localization
    url = 'https://dcc.ligo.org/public/0122/P1500227/012/bayestar_gstlal_C01.fits.gz'
    m, meta = read_sky_map(url)

    # Plot sky map on an orthographic projection
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(
        111, projection='astro globe', center='130d -70d')
    ax.imshow_hpx(m, cmap='cylon')

    # Hide the original ('RA', 'Dec') ticks
    for coord in ax.coords:
        coord.set_ticks_visible(False)
        coord.set_ticklabel_visible(False)

    # Construct Hanford-Livingston detector frame at the time of the event
    frame = DetectorFrame(site_1=EarthLocation.of_site('H1'),
                          site_2=EarthLocation.of_site('L1'),
                          obstime=Time(meta['gps_time'], format='gps'))

    # Draw grid for detector frame
    ax.get_coords_overlay(frame).grid()

References
----------
.. [1] LSC/Virgo et al., 2016. "Localization and Broadband Follow-up of the
       Gravitational-wave Transient GW150914." ApJL 826, L13.
       :doi:`10.3847/2041-8205/826/1/L13`

"""  # noqa: E501
from astropy.coordinates import (
    CartesianRepresentation, DynamicMatrixTransform, EarthLocation,
    EarthLocationAttribute, frame_transform_graph, ITRS,
    SphericalRepresentation)
from astropy.coordinates.matrix_utilities import matrix_transpose
from astropy import units as u
import lal
import numpy as np

__all__ = ('DetectorFrame',)

# Add gravitational-wave detectors to astropy's site registry, so that they can
# be looked up with EarthLocation.of_site() by their LAL name or prefix (the
# module example above looks up 'H1' and 'L1').
# NOTE: _get_site_registry() is a private astropy API and may change.
registry = EarthLocation._get_site_registry()
for detector in lal.CachedDetectors:
    # Register each detector under both its full name and its short prefix.
    names = [detector.frDetector.name, detector.frDetector.prefix]
    # Three positional values with unit=u.m are taken by EarthLocation as
    # geocentric Cartesian x, y, z.
    location = EarthLocation(*detector.location, unit=u.m)
    registry.add_site(names, location)
    del names, detector
# Remove the helper names from the module namespace (`location` remains).
del lal, registry


class DetectorFrame(ITRS):
    """A coordinate frame to visualize triangulation rings from pairs of
    gravitational-wave detectors.

    The frame's z axis points along the baseline from ``site_2`` to
    ``site_1``. Circles of constant latitude in this frame are therefore
    directions with a constant projection onto the baseline, i.e. rings of
    constant arrival-time difference between the two detectors.
    """

    site_1 = EarthLocationAttribute()
    site_2 = EarthLocationAttribute()

    default_representation = SphericalRepresentation


@frame_transform_graph.transform(DynamicMatrixTransform, ITRS, DetectorFrame)
def itrs_to_detectorframe(from_coo, to_frame):
    """Return the rotation matrix from ITRS to a :class:`DetectorFrame`.

    The axes are defined as follows:

    * z points from ``site_2`` to ``site_1``.
    * x is perpendicular to both the ITRS z axis and that baseline.
    * y completes a right-handed set.

    The rows of the returned 3x3 matrix are these unit axes in ITRS
    Cartesian components.
    """
    e_z = CartesianRepresentation(u.Quantity(to_frame.site_1.geocentric) -
                                  u.Quantity(to_frame.site_2.geocentric))
    e_z /= e_z.norm()
    e_x = CartesianRepresentation(0, 0, 1).cross(e_z)
    e_x /= e_x.norm()
    e_y = e_z.cross(e_x)

    # np.vstack replaces np.row_stack, which is deprecated as of NumPy 2.0.
    return np.vstack((e_x.xyz.value,
                      e_y.xyz.value,
                      e_z.xyz.value))


@frame_transform_graph.transform(DynamicMatrixTransform, DetectorFrame, ITRS)
def detectorframe_to_itrs(from_coo, to_frame):
    """Return the rotation matrix from a :class:`DetectorFrame` back to ITRS.

    The forward matrix is built from the detector frame's sites, which here
    belong to ``from_coo``. The arguments are therefore swapped when calling
    :func:`itrs_to_detectorframe`. The result is then transposed, because the
    transpose of a rotation matrix is its inverse.
    """
    forward = itrs_to_detectorframe(to_frame, from_coo)
    return matrix_transpose(forward)

ligo/skymap/coordinates/eigenframe.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
#
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Astropy coordinate frame for eigendecomposition of a cloud of points or a 3D
sky map.
"""

from astropy.coordinates import (
    BaseCoordinateFrame, CartesianRepresentation, DynamicMatrixTransform,
    frame_transform_graph, ICRS, SphericalRepresentation)
from astropy.coordinates import CartesianRepresentationAttribute
from astropy.units import dimensionless_unscaled
import numpy as np

from ..distance import principal_axes

__all__ = ('EigenFrame',)


class EigenFrame(BaseCoordinateFrame):
    """A coordinate frame that has its axes aligned with the principal
    components of a cloud of points.

    The frame attributes ``e_x``, ``e_y`` and ``e_z`` are the frame's axis
    vectors in ICRS Cartesian components. By default they coincide with the
    ICRS axes. Use :meth:`for_coords` or :meth:`for_skymap` to build a frame
    from data.
    """

    e_x = CartesianRepresentationAttribute(
        default=CartesianRepresentation(1, 0, 0, unit=dimensionless_unscaled),
        unit=dimensionless_unscaled)
    e_y = CartesianRepresentationAttribute(
        default=CartesianRepresentation(0, 1, 0, unit=dimensionless_unscaled),
        unit=dimensionless_unscaled)
    e_z = CartesianRepresentationAttribute(
        default=CartesianRepresentation(0, 0, 1, unit=dimensionless_unscaled),
        unit=dimensionless_unscaled)

    default_representation = SphericalRepresentation

    @classmethod
    def for_coords(cls, coords):
        """Create a coordinate frame that has its axes aligned with the
        principal components of a cloud of points.

        The axes are the eigenvectors of the scatter matrix of the points'
        ICRS unit vectors, ordered by descending eigenvalue, so ``e_x`` is
        the direction of greatest spread.

        Parameters
        ----------
        coords : `astropy.coordinates.SkyCoord`
            A cloud of points

        Returns
        -------
        frame : `EigenFrame`
            A new coordinate frame

        """
        v = coords.icrs.cartesian.xyz.value
        _, r = np.linalg.eigh(np.dot(v, v.T))
        r = r[:, ::-1]  # Order by descending eigenvalue
        e_x, e_y, e_z = CartesianRepresentation(r, unit=dimensionless_unscaled)
        return cls(e_x=e_x, e_y=e_y, e_z=e_z)

    @classmethod
    def for_skymap(cls, prob, distmu, distsigma, nest=False):
        """Create a coordinate frame that has its axes aligned with the
        principal components of a 3D sky map.

        Parameters
        ----------
        prob : `numpy.ndarray`
            Marginal probability (pix^-2)
        distmu : `numpy.ndarray`
            Distance location parameter (Mpc)
        distsigma : `numpy.ndarray`
            Distance scale parameter (Mpc)
        nest : bool, default=False
            Indicates whether the input sky map is in nested rather than
            ring-indexed HEALPix coordinates (default: ring).

        Returns
        -------
        frame : `EigenFrame`
            A new coordinate frame

        """
        r = principal_axes(prob, distmu, distsigma, nest=nest)
        # Presumably principal_axes returns columns in ascending eigenvalue
        # order (as np.linalg.eigh does); reverse them to match for_coords.
        r = r[:, ::-1]  # Order by descending eigenvalue
        e_x, e_y, e_z = CartesianRepresentation(r, unit=dimensionless_unscaled)
        return cls(e_x=e_x, e_y=e_y, e_z=e_z)


@frame_transform_graph.transform(DynamicMatrixTransform, ICRS, EigenFrame)
def icrs_to_eigenframe(from_coo, to_frame):
    """Return the rotation matrix from ICRS to an :class:`EigenFrame`.

    The rows of the matrix are the target frame's axis vectors in ICRS
    Cartesian components.
    """
    # np.vstack replaces np.row_stack, which is deprecated as of NumPy 2.0.
    return np.vstack((to_frame.e_x.xyz.value,
                      to_frame.e_y.xyz.value,
                      to_frame.e_z.xyz.value))


@frame_transform_graph.transform(DynamicMatrixTransform, EigenFrame, ICRS)
def eigenframe_to_icrs(from_coo, to_frame):
    """Return the rotation matrix from an :class:`EigenFrame` back to ICRS.

    The columns of the matrix are the source frame's axis vectors in ICRS
    Cartesian components.
    """
    axes = (from_coo.e_x, from_coo.e_y, from_coo.e_z)
    return np.column_stack([axis.xyz.value for axis in axes])

ligo/skymap/io/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

__all__ = ()

# Import all symbols from all submodules of this module.
# For every module found in this package's directory (except 'tests'), the
# exec'd snippet (1) imports the submodule itself, (2) appends the
# submodule's __all__ (if it defines one) to this package's __all__, and
# (3) star-imports the submodule's public names into this namespace.
# exec is used so that the relative imports run in this module's globals.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    # Remove the loop variable so it does not linger in the namespace.
    # NOTE(review): `_` is not deleted and stays bound after the loop.
    del module

# Clean up
del os, pkgutil

ligo/skymap/io/fits.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
#!/usr/bin/env python
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Reading and writing HEALPix FITS files.

An example FITS header looks like this:

.. code-block:: sh

    $ fitsheader test.fits.gz
    # HDU 0 in test.fits.gz
    SIMPLE  =                    T / conforms to FITS standard
    BITPIX  =                    8 / array data type
    NAXIS   =                    0 / number of array dimensions
    EXTEND  =                    T

    # HDU 1 in test.fits.gz
    XTENSION= 'BINTABLE'           / binary table extension
    BITPIX  =                    8 / array data type
    NAXIS   =                    2 / number of array dimensions
    NAXIS1  =                 4096 / length of dimension 1
    NAXIS2  =                  192 / length of dimension 2
    PCOUNT  =                    0 / number of group parameters
    GCOUNT  =                    1 / number of groups
    TFIELDS =                    1 / number of table fields
    TTYPE1  = 'PROB    '
    TFORM1  = '1024E   '
    TUNIT1  = 'pix-1   '
    PIXTYPE = 'HEALPIX '           / HEALPIX pixelisation
    ORDERING= 'RING    '           / Pixel ordering scheme, either RING or NESTED
    COORDSYS= 'C       '           / Ecliptic, Galactic or Celestial (equatorial)
    EXTNAME = 'xtension'           / name of this binary table extension
    NSIDE   =                  128 / Resolution parameter of HEALPIX
    FIRSTPIX=                    0 / First pixel # (0 based)
    LASTPIX =               196607 / Last pixel # (0 based)
    INDXSCHM= 'IMPLICIT'           / Indexing: IMPLICIT or EXPLICIT
    OBJECT  = 'FOOBAR 12345'       / Unique identifier for this event
    REFERENC= 'http://www.youtube.com/watch?v=0ccKPSVQcFk' / URL of this event
    DATE-OBS= '2013-04-08T21:37:32.25' / UTC date of the observation
    MJD-OBS =      56391.151064815 / modified Julian date of the observation
    DATE    = '2013-04-08T21:50:32' / UTC date of file creation
    CREATOR = 'fits.py '           / Program that created this file
    RUNTIME =                 21.5 / Runtime in seconds of the CREATOR program
"""  # noqa: E501

import logging
import healpy as hp
import numpy as np
from astropy.io import fits
from astropy.time import Time
from astropy import units as u
from ligo.lw import lsctables, ilwd
import itertools
import astropy_healpix as ah
from astropy.table import Table
from .. import moc

# Module-level logger named after this module (instead of the root logger),
# so applications can tune this module's verbosity independently; records
# still propagate to the root logger's handlers.
log = logging.getLogger(__name__)

# Public API of this module; the io package __init__ star-imports it and
# appends these names to its own __all__.
__all__ = ("read_sky_map", "write_sky_map")


def gps_to_iso8601(gps_time):
    """Convert a floating-point GPS time in seconds to an ISO 8601 date string.

    Parameters
    ----------
    gps_time : float
        Time in seconds since GPS epoch

    Returns
    -------
    iso8601 : str
        ISO 8601 UTC date string (with fractional seconds, rounded to
        microsecond precision)

    Examples
    --------
    >>> gps_to_iso8601(1000000000.01)
    '2011-09-14T01:46:25.010000'
    >>> gps_to_iso8601(1000000000)
    '2011-09-14T01:46:25.000000'
    >>> gps_to_iso8601(1000000000.999999)
    '2011-09-14T01:46:25.999999'
    >>> gps_to_iso8601(1000000000.9999999)
    '2011-09-14T01:46:26.000000'
    >>> gps_to_iso8601(1000000814.999999)
    '2011-09-14T01:59:59.999999'
    >>> gps_to_iso8601(1000000814.9999999)
    '2011-09-14T02:00:00.000000'

    """
    # precision=6 rounds to microseconds, so fractions beyond that roll over
    # into the next second (see the examples above).
    return Time(float(gps_time), format='gps', precision=6).utc.isot


def iso8601_to_gps(iso8601):
    """Convert an ISO 8601 date string to a floating-point GPS time in seconds.

    The string is interpreted on the UTC time scale.

    Parameters
    ----------
    iso8601 : str
        ISO 8601 date string (with fractional seconds)

    Returns
    -------
    gps : float
        Time in seconds since GPS epoch

    Examples
    --------
    >>> gps_to_iso8601(1129501781.2)
    '2015-10-21T22:29:24.200000'
    >>> iso8601_to_gps('2015-10-21T22:29:24.2')
    1129501781.2

    """
    utc_time = Time(iso8601, scale='utc')
    return utc_time.gps


def gps_to_mjd(gps_time):
    """Convert a floating-point GPS time in seconds to a modified Julian day.

    The result is expressed on the UTC time scale.

    Parameters
    ----------
    gps_time : float
        Time in seconds since GPS epoch

    Returns
    -------
    mjd : float
        Modified Julian day

    Examples
    --------
    >>> '%.9f' % round(gps_to_mjd(1129501781.2), 9)
    '57316.937085648'

    """
    gps = Time(gps_time, format='gps')
    utc = gps.utc
    return utc.mjd


def identity(x):
    """Return ``x`` unchanged (pass-through metadata/FITS converter)."""
    unchanged = x
    return unchanged


def instruments_to_fits(value):
    """Format an instrument collection as a FITS header string.

    Strings are passed through unchanged. Any other value is formatted with
    ``lsctables.instrumentsproperty.set`` and converted to `str`.
    """
    if isinstance(value, str):
        return value
    formatted = lsctables.instrumentsproperty.set(value)
    return str(formatted)


def instruments_from_fits(value):
    """Parse a FITS instrument string into a set of instrument names (str).

    Parsing is delegated to ``lsctables.instrumentsproperty.get``.
    """
    parsed = lsctables.instrumentsproperty.get(value)
    return set(map(str, parsed))


def metadata_for_version_module(version):
    """Build sky map metadata describing the software version.

    Parameters
    ----------
    version : module
        A version module; its parent package name and its ``version``
        attribute are joined with a space.

    Returns
    -------
    dict
        ``{'vcs_version': '<package> <version>'}``
    """
    package_name = version.__spec__.parent
    release = version.version
    return {'vcs_version': package_name + ' ' + release}


def normalize_objid(objid):
    """Normalize an event identifier for the OBJECT header card.

    Returns ``int(objid)`` if that succeeds. Otherwise tries to parse it as an
    ``ilwd:char`` identifier (presumably of the form ``table:column:N``) and
    returns its integer value. Otherwise returns ``str(objid)``. Only
    `ValueError` is treated as "not convertible"; other exceptions propagate.
    """
    converters = (int, lambda text: int(ilwd.ilwdchar(text)))
    for convert in converters:
        try:
            return convert(objid)
        except ValueError:
            continue
    return str(objid)


# Default column names and units for multi-order (NUNIQ) sky maps. When
# writing, write_sky_map assigns these units to columns that have none.
DEFAULT_NUNIQ_NAMES = ('PROBDENSITY', 'DISTMU', 'DISTSIGMA', 'DISTNORM')
DEFAULT_NUNIQ_UNITS = (u.steradian**-1, u.Mpc, u.Mpc, u.Mpc**-2)
# Default column names and units for single-resolution (RING/NESTED) sky
# maps. write_sky_map assigns these names positionally to unnamed input
# arrays and fills in missing units. read_sky_map returns the columns in this
# order when distances=True.
DEFAULT_NESTED_NAMES = ('PROB', 'DISTMU', 'DISTSIGMA', 'DISTNORM')
DEFAULT_NESTED_UNITS = (u.pix**-1, u.Mpc, u.Mpc, u.Mpc**-2)
# Mapping between metadata dictionary keys and FITS header cards. Each row
# is (metadata key, FITS keyword, FITS comment, to_fits, from_fits):
# to_fits converts the Python value to a header value when writing, and
# from_fits converts the header value back when reading. A None converter
# skips that direction (MJD-OBS is written but never read back).
#
# Rows that share a metadata key (here 'gps_time') must stay adjacent:
# write_sky_map groups rows by metadata key and read_sky_map by FITS keyword
# using itertools.groupby, which only merges consecutive rows.
FITS_META_MAPPING = (
    ('objid', 'OBJECT', 'Unique identifier for this event',
     normalize_objid, normalize_objid),
    ('url', 'REFERENC', 'URL of this event', identity, identity),
    ('instruments', 'INSTRUME', 'Instruments that triggered this event',
     instruments_to_fits, instruments_from_fits),
    ('gps_time', 'DATE-OBS', 'UTC date of the observation',
     gps_to_iso8601, iso8601_to_gps),
    ('gps_time', 'MJD-OBS', 'modified Julian date of the observation',
     gps_to_mjd, None),
    ('gps_creation_time', 'DATE', 'UTC date of file creation',
     gps_to_iso8601, iso8601_to_gps),
    ('creator', 'CREATOR', 'Program that created this file',
     identity, identity),
    ('origin', 'ORIGIN', 'Organization responsible for this FITS file',
     identity, identity),
    ('runtime', 'RUNTIME', 'Runtime in seconds of the CREATOR program',
     identity, identity),
    ('distmean', 'DISTMEAN', 'Posterior mean distance (Mpc)',
     identity, identity),
    ('diststd', 'DISTSTD', 'Posterior standard deviation of distance (Mpc)',
     identity, identity),
    ('log_bci', 'LOGBCI', 'Log Bayes factor: coherent vs. incoherent',
     identity, identity),
    ('log_bsn', 'LOGBSN', 'Log Bayes factor: signal vs. noise',
     identity, identity),
    ('vcs_version', 'VCSVERS', 'Software version',
     identity, identity),
    ('vcs_revision', 'VCSREV', 'Software revision (Git)',
     identity, identity),
    ('build_date', 'DATE-BLD', 'Software build date',
     identity, identity))


def write_sky_map(filename, m, **kwargs):
    """Write a gravitational-wave sky map to a file, populating the header
    with optional metadata.

    Parameters
    ----------
    filename : str
        Path to the optionally gzip-compressed FITS file.

    m : `astropy.table.Table`, `numpy.array`
        If a Numpy record array or `astropy.table.Table` instance, and has a
        column named 'UNIQ', then interpret the input as NUNIQ-style
        multi-order map [1]_. Otherwise, interpret as a NESTED or RING ordered
        map. A plain array (or a sequence of arrays) is given the column
        names PROB, DISTMU, DISTSIGMA, DISTNORM, in that order.

    **kwargs
        Additional metadata to add to FITS header. If m is an
        `astropy.table.Table` instance, then the header is initialized from
        both `m.meta` and `kwargs`. The special key ``nest`` selects NESTED
        (True) or RING (False, the default) ordering for single-resolution
        maps and is ignored for multi-order maps.

    References
    ----------
    .. [1] Górski, K.M., Wandelt, B.D., Hivon, E., Hansen, F.K., & Banday, A.J.
        2017. The HEALPix Primer. The Unique Identifier scheme.
        http://healpix.sourceforge.net/html/intronode4.htm#SECTION00042000000000000000

    Examples
    --------
    Test header contents:

    >>> order = 9
    >>> nside = 2 ** order
    >>> npix = ah.nside_to_npix(nside)
    >>> prob = np.ones(npix, dtype=float) / npix

    >>> import tempfile
    >>> with tempfile.NamedTemporaryFile(suffix='.fits') as f:
    ...     write_sky_map(f.name, prob, nest=True,
    ...                   vcs_version='foo 1.0', vcs_revision='bar',
    ...                   build_date='2018-01-01T00:00:00')
    ...     for card in fits.getheader(f.name, 1).cards:
    ...         print(str(card).rstrip())
    XTENSION= 'BINTABLE'           / binary table extension
    BITPIX  =                    8 / array data type
    NAXIS   =                    2 / number of array dimensions
    NAXIS1  =                    8 / length of dimension 1
    NAXIS2  =              3145728 / length of dimension 2
    PCOUNT  =                    0 / number of group parameters
    GCOUNT  =                    1 / number of groups
    TFIELDS =                    1 / number of table fields
    TTYPE1  = 'PROB    '
    TFORM1  = 'D       '
    TUNIT1  = 'pix-1   '
    PIXTYPE = 'HEALPIX '           / HEALPIX pixelisation
    ORDERING= 'NESTED  '           / Pixel ordering scheme: RING, NESTED, or NUNIQ
    COORDSYS= 'C       '           / Ecliptic, Galactic or Celestial (equatorial)
    NSIDE   =                  512 / Resolution parameter of HEALPIX
    INDXSCHM= 'IMPLICIT'           / Indexing: IMPLICIT or EXPLICIT
    VCSVERS = 'foo 1.0 '           / Software version
    VCSREV  = 'bar     '           / Software revision (Git)
    DATE-BLD= '2018-01-01T00:00:00' / Software build date

    >>> uniq = moc.nest2uniq(np.uint8(order), np.arange(npix))
    >>> probdensity = prob / hp.nside2pixarea(nside)
    >>> moc_data = np.rec.fromarrays(
    ...     [uniq, probdensity], names=['UNIQ', 'PROBDENSITY'])
    >>> with tempfile.NamedTemporaryFile(suffix='.fits') as f:
    ...     write_sky_map(f.name, moc_data,
    ...                   vcs_version='foo 1.0', vcs_revision='bar',
    ...                   build_date='2018-01-01T00:00:00')
    ...     for card in fits.getheader(f.name, 1).cards:
    ...         print(str(card).rstrip())
    XTENSION= 'BINTABLE'           / binary table extension
    BITPIX  =                    8 / array data type
    NAXIS   =                    2 / number of array dimensions
    NAXIS1  =                   16 / length of dimension 1
    NAXIS2  =              3145728 / length of dimension 2
    PCOUNT  =                    0 / number of group parameters
    GCOUNT  =                    1 / number of groups
    TFIELDS =                    2 / number of table fields
    TTYPE1  = 'UNIQ    '
    TFORM1  = 'K       '
    TTYPE2  = 'PROBDENSITY'
    TFORM2  = 'D       '
    TUNIT2  = 'sr-1    '
    PIXTYPE = 'HEALPIX '           / HEALPIX pixelisation
    ORDERING= 'NUNIQ   '           / Pixel ordering scheme: RING, NESTED, or NUNIQ
    COORDSYS= 'C       '           / Ecliptic, Galactic or Celestial (equatorial)
    MOCORDER=                    9 / MOC resolution (best order)
    INDXSCHM= 'EXPLICIT'           / Indexing: IMPLICIT or EXPLICIT
    VCSVERS = 'foo 1.0 '           / Software version
    VCSREV  = 'bar     '           / Software revision (Git)
    DATE-BLD= '2018-01-01T00:00:00' / Software build date

    """  # noqa: E501
    log.debug('normalizing metadata')
    if isinstance(m, Table) or (isinstance(m, np.ndarray) and m.dtype.names):
        m = Table(m, copy=False)
    else:
        # A single 1-D array is a one-column map; name columns positionally.
        if np.ndim(m) == 1:
            m = [m]
        m = Table(m, names=DEFAULT_NESTED_NAMES[:len(m)], copy=False)
    m.meta.update(kwargs)

    if 'UNIQ' in m.colnames:
        default_names = DEFAULT_NUNIQ_NAMES
        default_units = DEFAULT_NUNIQ_UNITS
        extra_header = [
            ('PIXTYPE', 'HEALPIX',
             'HEALPIX pixelisation'),
            ('ORDERING', 'NUNIQ',
             'Pixel ordering scheme: RING, NESTED, or NUNIQ'),
            ('COORDSYS', 'C',
             'Ecliptic, Galactic or Celestial (equatorial)'),
            ('MOCORDER', moc.uniq2order(m['UNIQ'].max()),
             'MOC resolution (best order)'),
            ('INDXSCHM', 'EXPLICIT',
             'Indexing: IMPLICIT or EXPLICIT')]
        # Ignore nest keyword argument if present
        m.meta.pop('nest', False)
    else:
        default_names = DEFAULT_NESTED_NAMES
        default_units = DEFAULT_NESTED_UNITS
        ordering = 'NESTED' if m.meta.pop('nest', False) else 'RING'
        extra_header = [
            ('PIXTYPE', 'HEALPIX',
             'HEALPIX pixelisation'),
            ('ORDERING', ordering,
             'Pixel ordering scheme: RING, NESTED, or NUNIQ'),
            ('COORDSYS', 'C',
             'Ecliptic, Galactic or Celestial (equatorial)'),
            ('NSIDE', ah.npix_to_nside(len(m)),
             'Resolution parameter of HEALPIX'),
            ('INDXSCHM', 'IMPLICIT',
             'Indexing: IMPLICIT or EXPLICIT')]

    # Move known metadata keys into header cards. Rows sharing a metadata
    # key are adjacent in FITS_META_MAPPING, so one popped value (e.g.
    # gps_time) can feed several cards (DATE-OBS and MJD-OBS). Keys that are
    # not in the mapping remain in m.meta.
    for key, rows in itertools.groupby(FITS_META_MAPPING, lambda row: row[0]):
        try:
            value = m.meta.pop(key)
        except KeyError:
            pass
        else:
            for row in rows:
                _, fits_key, fits_comment, to_fits, _ = row
                if to_fits is not None:
                    extra_header.append(
                        (fits_key, to_fits(value), fits_comment))

    # Fill in default units only for recognized columns lacking one.
    for default_name, default_unit in zip(default_names, default_units):
        try:
            col = m[default_name]
        except KeyError:
            pass
        else:
            if not col.unit:
                col.unit = default_unit

    log.debug('converting from Astropy table to FITS HDU list')
    hdu = fits.table_to_hdu(m)
    hdu.header.extend(extra_header)
    hdulist = fits.HDUList([fits.PrimaryHDU(), hdu])
    log.debug('saving')
    hdulist.writeto(filename, overwrite=True)


def read_sky_map(filename, nest=False, distances=False, moc=False, **kwargs):
    """Read a LIGO/Virgo-type sky map and return a tuple of the HEALPix array
    and a dictionary of metadata from the header.

    Parameters
    ----------
    filename : string
        Path to the optionally gzip-compressed FITS file.

    nest : bool, optional
        If omitted or False, then detect the pixel ordering in the FITS file
        and rearrange if necessary to RING indexing before returning.

        If True, then detect the pixel ordering and rearrange if necessary to
        NESTED indexing before returning.

        If None, then preserve the ordering from the FITS file.

        Regardless of the value of this option, the ordering used in the FITS
        file is indicated as the value of the 'nest' key in the metadata
        dictionary. A multi-order file that is rasterized (``moc=False``) is
        reported as ``nest=True``.

    distances : bool, optional
        If true, then also read the additional HEALPix layers representing
        the conditional mean and standard deviation of distance as a function
        of sky location, as well as the distance normalization.

    moc : bool, optional
        If true, then preserve multi-order structure if present.

    **kwargs
        Additional keyword arguments passed to `astropy.table.Table.read`.

    Returns
    -------
    m : `astropy.table.Table`
        If ``moc`` is True: the multi-order sky map, with the header metadata
        in ``m.meta``. Single-resolution input is first converted with
        ``bayestar.derasterize``.
    prob, meta : tuple
        If ``moc`` and ``distances`` are both False: the PROB column as a
        `numpy.ndarray`, and the metadata dictionary.
    (prob, distmu, distsigma, distnorm), meta : tuple
        If ``moc`` is False and ``distances`` is True: the PROB, DISTMU,
        DISTSIGMA, and DISTNORM columns as `numpy.ndarray` objects, and the
        metadata dictionary.

    Examples
    --------
    Test that we can read a legacy IDL-compatible file
    (https://bugs.ligo.org/redmine/issues/5168):

    >>> import tempfile
    >>> with tempfile.NamedTemporaryFile(suffix='.fits') as f:
    ...     nside = 512
    ...     npix = ah.nside_to_npix(nside)
    ...     ipix_nest = np.arange(npix)
    ...     hp.write_map(f.name, ipix_nest, nest=True, column_names=['PROB'])
    ...     m, meta = read_sky_map(f.name)
    ...     np.testing.assert_array_equal(m, hp.ring2nest(nside, ipix_nest))

    """
    m = Table.read(filename, format='fits', **kwargs)

    # Remove some keys that we do not need
    for key in (
            'PIXTYPE', 'EXTNAME', 'NSIDE', 'FIRSTPIX', 'LASTPIX', 'INDXSCHM',
            'MOCORDER'):
        m.meta.pop(key, None)

    if m.meta.pop('COORDSYS', 'C') != 'C':
        raise ValueError('ligo.skymap only reads and writes sky maps in '
                         'equatorial coordinates.')

    # Translate the ORDERING card into the boolean 'nest' metadata key.
    # NUNIQ files get no 'nest' key here.
    try:
        value = m.meta.pop('ORDERING')
    except KeyError:
        pass
    else:
        if value == 'RING':
            m.meta['nest'] = False
        elif value == 'NESTED':
            m.meta['nest'] = True
        elif value == 'NUNIQ':
            pass
        else:
            raise ValueError(
                'ORDERING card in header has unknown value: {0}'.format(value))

    # Convert recognized header cards back into Python metadata values.
    for fits_key, rows in itertools.groupby(
            FITS_META_MAPPING, lambda row: row[1]):
        try:
            value = m.meta.pop(fits_key)
        except KeyError:
            pass
        else:
            for row in rows:
                key, _, _, _, from_fits = row
                if from_fits is not None:
                    m.meta[key] = from_fits(value)

    # FIXME: Fermi GBM HEALPix maps use the column name 'PROBABILITY',
    # instead of the LIGO/Virgo convention of 'PROB'.
    #
    # Fermi may change to our convention in the future, but for now we
    # rename the column.
    if 'PROBABILITY' in m.colnames:
        m.rename_column('PROBABILITY', 'PROB')

    # For a long time, we produced files with a UNIQ column that was an
    # unsigned integer. Cast it here to a signed integer so that the user
    # can handle old or new sky maps the same way.
    if 'UNIQ' in m.colnames:
        m['UNIQ'] = m['UNIQ'].astype(np.int64)

    # Flatten multidimensional columns (e.g. legacy IDL-style files that
    # store the map in fixed-size row chunks) into 1-D pixel arrays.
    if 'UNIQ' not in m.colnames:
        m = Table([col.ravel() for col in m.columns.values()], meta=m.meta)

    if 'UNIQ' in m.colnames and not moc:
        from ..bayestar import rasterize
        m = rasterize(m)
        m.meta['nest'] = True
    elif 'UNIQ' not in m.colnames and moc:
        from ..bayestar import derasterize
        if not m.meta['nest']:
            # RING -> NESTED: entry for nested index i comes from ring index
            # nest2ring(i).
            npix = len(m)
            nside = ah.npix_to_nside(npix)
            m = m[hp.nest2ring(nside, np.arange(npix))]
        m = derasterize(m)
        m.meta.pop('nest', None)

    if 'UNIQ' not in m.colnames:
        npix = len(m)
        nside = ah.npix_to_nside(npix)

        if nest is None:
            pass
        elif m.meta['nest'] and not nest:
            # NESTED -> RING: entry for ring index j comes from nested index
            # ring2nest(j).
            m = m[hp.ring2nest(nside, np.arange(npix))]
        elif not m.meta['nest'] and nest:
            # RING -> NESTED
            m = m[hp.nest2ring(nside, np.arange(npix))]

    if moc:
        return m
    elif distances:
        return tuple(
            np.asarray(m[name]) for name in DEFAULT_NESTED_NAMES), m.meta
    else:
        return np.asarray(m[DEFAULT_NESTED_NAMES[0]]), m.meta


if __name__ == '__main__':
    # Smoke test: write a random single-resolution sky map with example
    # metadata to test.fits.gz, then read it back and print the result.
    import os
    nside = 128
    npix = ah.nside_to_npix(nside)
    prob = np.random.random(npix)
    # Normalize with NumPy's vectorized sum; the builtin sum() would iterate
    # over the array element by element in Python.
    prob /= prob.sum()

    write_sky_map(
        'test.fits.gz', prob,
        objid='FOOBAR 12345',
        gps_time=1049492268.25,
        creator=os.path.basename(__file__),
        url='http://www.youtube.com/watch?v=0ccKPSVQcFk',
        origin='LIGO Scientific Collaboration',
        runtime=21.5)

    print(read_sky_map('test.fits.gz'))

ligo/skymap/io/hdf5.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
# Copyright (C) 2016-2020  Leo Singer, John Veitch
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read HDF5 posterior sample chain HDF5 files."""

import numpy as np
import h5py
from astropy.table import Column, Table

# Constants from lalinference module
# Default dataset name that read_samples searches for in the HDF5 file.
POSTERIOR_SAMPLES = 'posterior_samples'
# Values of a column's 'vary' metadata flag. read_samples restores these from
# FIELD_<i>_VARY table attributes, defaulting to OUTPUT, and turns other
# (non-H5TB) table attributes into constant columns flagged FIXED.
LINEAR = 0
CIRCULAR = 1
FIXED = 2
OUTPUT = 3

# Public API; the io package __init__ star-imports this module and appends
# these names to its own __all__.
__all__ = ('read_samples', 'write_samples')


def _identity(x):
    return x


# Alternative column names (presumably from different sample-file producers)
# mapped to the canonical short names used by ligo.skymap. Each entry is
# (name in file, canonical name, function applied to the column values);
# 'logdistance' is converted back to distance with np.exp. _remap_colnames
# applies the entries in this order, so when several aliases of the same
# canonical name are present, the later entry's values win.
_colname_map = (('rightascension', 'ra', _identity),
                ('right_ascension', 'ra', _identity),
                ('declination', 'dec', _identity),
                ('logdistance', 'dist', np.exp),
                ('distance', 'dist', _identity),
                ('luminosity_distance', 'dist', _identity),
                ('polarisation', 'psi', _identity),
                ('chirpmass', 'mc', _identity),
                ('chirp_mass', 'mc', _identity),
                ('a_spin1', 'a1', _identity),
                ('a_1', 'a1', _identity),
                ('a_spin2', 'a2', _identity),
                ('a_2', 'a2', _identity),
                ('tilt_spin1', 'tilt1', _identity),
                ('tilt_1', 'tilt1', _identity),
                ('tilt_spin2', 'tilt2', _identity),
                ('tilt_2', 'tilt2', _identity),
                ('geocent_time', 'time', _identity))


def _remap_colnames(table):
    """Rename known alias columns of ``table`` in place to canonical names.

    For each ``(alias, canonical, transform)`` entry of ``_colname_map``
    whose alias is a column of ``table``, that column is removed and
    ``transform(column)`` is stored under the canonical name. Entries are
    processed in order, so a later alias for the same canonical name
    overwrites the values stored by an earlier one.
    """
    for alias, canonical, transform in _colname_map:
        if alias not in table.colnames:
            continue
        popped = table.columns.pop(alias)
        table[canonical] = transform(popped)


def _find_table(group, tablename):
    """Recursively search an HDF5 group or file for a dataset by name.

    Parameters
    ----------
    group : `h5py.File` or `h5py.Group`
        The file or group to search
    tablename : str
        The name of the table to search for

    Returns
    -------
    dataset : `h5py.Dataset`
        The dataset whose name is `tablename`

    Raises
    ------
    KeyError
        If the table is not found or if multiple matching tables are found

    Examples
    --------
    Check that we can find a file by name:

    >>> import os.path
    >>> import tempfile
    >>> table = Table(np.eye(3), names=['a', 'b', 'c'])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'bar')
    <HDF5 dataset "bar": shape (3,), type "|V24">

    Check that an exception is raised if the table is not found:

    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'plugh')
    Traceback (most recent call last):
        ...
    KeyError: 'Table not found: plugh'

    Check that an exception is raised if multiple tables are found:

    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'bat')
    Traceback (most recent call last):
        ...
    KeyError: 'Multiple tables called bat exist: foo/bat, foo/xyzzy/bat'

    """
    results = {}

    def visitor(key, value):
        _, _, name = key.rpartition('/')
        if name == tablename:
            results[key] = value

    group.visititems(visitor)

    if len(results) == 0:
        raise KeyError('Table not found: {0}'.format(tablename))

    if len(results) > 1:
        raise KeyError('Multiple tables called {0} exist: {1}'.format(
            tablename, ', '.join(sorted(results.keys()))))

    table, = results.values()
    return table


def read_samples(filename, path=None, tablename=POSTERIOR_SAMPLES):
    """Read an HDF5 sample chain file.

    Column 'vary' flags are restored from ``FIELD_<i>_VARY`` attributes
    (default OUTPUT). Other non-H5TB table attributes become constant columns
    flagged FIXED. The table metadata is then cleared and known alias column
    names are normalized.

    Parameters
    ----------
    filename : str
        The path of the HDF5 file on the filesystem.
    path : str, optional
        The path of the dataset within the HDF5 file.
    tablename : str, optional
        The name of table to search for recursively within the HDF5 file.
        By default, search for 'posterior_samples'.

    Returns
    -------
    chain : `astropy.table.Table`
        The sample chain as an Astropy table.

    Examples
    --------
    Test reading a file written using the Python API:

    >>> import os.path
    >>> import tempfile
    >>> table = Table([
    ...     Column(np.ones(10), name='foo', meta={'vary': FIXED}),
    ...     Column(np.arange(10), name='bar', meta={'vary': LINEAR}),
    ...     Column(np.arange(10) * np.pi, name='bat', meta={'vary': CIRCULAR}),
    ...     Column(np.arange(10), name='baz', meta={'vary': OUTPUT})
    ... ])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     write_samples(table, filename, path='foo/bar/posterior_samples')
    ...     len(read_samples(filename))
    10

    Test reading a file that was written using the LAL HDF5 C API:

    >>> from pkg_resources import resource_filename
    >>> filename = resource_filename(__name__, 'tests/data/test.hdf5')
    >>> table = read_samples(filename)
    >>> table.colnames
    ['uvw', 'opq', 'lmn', 'ijk', 'def', 'abc', 'ghi', 'rst']

    """
    with h5py.File(filename, 'r') as f:
        if path is None:
            dataset = _find_table(f, tablename)
        else:
            dataset = f[path]
        table = Table.read(dataset)

    attrs = table.meta

    # Restore vary types; columns without a recorded type are OUTPUT.
    for index, column in enumerate(table.columns.values()):
        vary_key = 'FIELD_{0}_VARY'.format(index)
        column.meta['vary'] = attrs.get(vary_key, OUTPUT)

    # Remaining attributes (other than H5TB bookkeeping, see
    # https://www.hdfgroup.org/HDF5/doc/HL/H5TB_Spec.html) hold fixed
    # parameters; expand each into a constant column.
    nrows = len(table)
    for key, value in attrs.items():
        if key in ('CLASS', 'VERSION', 'TITLE') or key.startswith('FIELD_'):
            continue
        fixed_column = Column([value] * nrows, name=key,
                              meta={'vary': FIXED})
        table.add_column(fixed_column)

    # Drop all table attributes, then normalize column names.
    attrs.clear()
    _remap_colnames(table)
    return table


def write_samples(table, filename, metadata=None, **kwargs):
    """Write an HDF5 sample chain file.

    Columns whose ``vary`` type is ``FIXED`` are not written as columns;
    their (single, repeated) value is stored as a table attribute instead.
    The ``vary`` types of the remaining columns are stored as
    ``FIELD_{i}_VARY`` table attributes. The input table is not modified.

    Parameters
    ----------
    table : `astropy.table.Table`
        The sample chain as an Astropy table.
    filename : str
        The path of the HDF5 file on the filesystem.
    metadata : dict, optional
        Mapping from HDF5 paths inside the output file to dictionaries of
        attribute (name, value) pairs to set on the object at that path.
    **kwargs
        Any keyword arguments for `astropy.table.Table.write`.

    Raises
    ------
    AssertionError
        If a column marked ``FIXED`` does not have identical values.
    KeyError
        If a path in ``metadata`` does not exist in the written file.

    Examples
    --------
    Check that we catch columns that are supposed to be FIXED but are not:

    >>> table = Table([
    ...     Column(np.arange(10), name='foo', meta={'vary': FIXED})
    ... ])
    >>> write_samples(table, 'bar.hdf5', path='bat/baz')
    Traceback (most recent call last):
        ...
    AssertionError: 
    Arrays are not equal
    Column foo is a fixed column, but its values are not identical
    ...

    And now try writing an arbitrary example to a temporary file.

    >>> import os.path
    >>> import tempfile
    >>> table = Table([
    ...     Column(np.ones(10), name='foo', meta={'vary': FIXED}),
    ...     Column(np.arange(10), name='bar', meta={'vary': LINEAR}),
    ...     Column(np.arange(10) * np.pi, name='bat', meta={'vary': CIRCULAR}),
    ...     Column(np.arange(10), name='baz', meta={'vary': OUTPUT}),
    ...     Column(np.ones(10), name='plugh'),
    ...     Column(np.arange(10), name='xyzzy')
    ... ])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     write_samples(
    ...         table, os.path.join(dir, 'test.hdf5'), path='bat/baz',
    ...         metadata={'bat/baz': {'widget': 'shoephone'}})

    """  # noqa: W291
    # Copy the table so that we do not modify the original.
    table = table.copy()

    # Make sure that all columns have a 'vary' type. A column without one is
    # FIXED if every value equals the first and OUTPUT otherwise. An empty
    # column has no first value to compare against, so it is OUTPUT (this
    # used to raise IndexError on column[0]).
    for column in table.columns.values():
        if 'vary' not in column.meta:
            if len(column) > 0 and np.all(column[0] == column[1:]):
                column.meta['vary'] = FIXED
            else:
                column.meta['vary'] = OUTPUT

    # Reconstruct table attributes: move each FIXED column into a scalar
    # table attribute. Iterate over a snapshot because columns are deleted.
    for colname, column in tuple(table.columns.items()):
        if column.meta['vary'] == FIXED:
            np.testing.assert_array_equal(column[1:], column[0],
                                          'Column {0} is a fixed column, but '
                                          'its values are not identical'
                                          .format(column.name))
            table.meta[colname] = column[0]
            del table[colname]

    # Record each remaining column's vary type by position, because the
    # column-level meta is removed before writing.
    for i, column in enumerate(table.columns.values()):
        table.meta['FIELD_{0}_VARY'.format(i)] = column.meta.pop('vary')
    table.write(filename, format='hdf5', **kwargs)

    # Attach any extra attributes to existing objects in the written file.
    if metadata:
        with h5py.File(filename, 'r+') as hdf:
            for internal_path, attributes in metadata.items():
                for key, value in attributes.items():
                    try:
                        hdf[internal_path].attrs[key] = value
                    except KeyError as e:
                        raise KeyError(
                            'Unable to set metadata {0}[{1}] = {2}'.format(
                                internal_path, key, value)) from e

ligo/skymap/io/events/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

__all__ = ()

# Import all symbols from all submodules of this module.
# For each module found in this package's directory (except the ``tests``
# package), the exec'd statements run in this package's namespace (at
# module level, exec's default globals/locals are the module dict) and:
#   1. ``from . import <module>``: bind the submodule by name;
#   2. extend this package's ``__all__`` with the submodule's ``__all__``
#      (or with nothing if it does not define one);
#   3. ``from .<module> import *``: re-export the submodule's public names.
# exec is used because submodule names are only known at import time; they
# come from the package directory listing, not from external input.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    # Do not leave the loop variable behind as a package attribute.
    del module

# Clean up
del os, pkgutil

ligo/skymap/io/events/base.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Base classes for reading events from search pipelines."""

from abc import ABCMeta, abstractmethod
from collections.abc import Mapping

__all__ = ('EventSource', 'Event', 'SingleEvent')


def _fmt(obj, keys):
    kvs = ', '.join('{}={!r}'.format(key, getattr(obj, key)) for key in keys)
    return '<{}({})>'.format(obj.__class__.__name__, kvs)


class EventSource(Mapping):
    """Abstraction of a source of coincident events.

    This is a mapping from event IDs (which may be any hashable type, but are
    generally integers or strings) to instances of `Event`.

    A subclass may be unable to report its length (for example, when it is
    backed by an iterable of unknown size). In that case ``len()`` may raise
    `NotImplementedError` or `TypeError`. Both `str` and `repr` then show an
    elided ``...`` instead of failing, and `repr` does not iterate over the
    source.
    """

    # Exceptions from len(self) that mean "length unknown". str() and repr()
    # must agree on this set. Previously repr() omitted TypeError and crashed
    # on sources whose len() raises it.
    _unknown_length_errors = (NotImplementedError, TypeError)

    def __str__(self):
        try:
            length = len(self)
        except self._unknown_length_errors:
            contents = '...'
        else:
            contents = '...{} items...'.format(length)
        return '<{}({{{}}})>'.format(self.__class__.__name__, contents)

    def __repr__(self):
        try:
            len(self)
        except self._unknown_length_errors:
            contents = '...'
        else:
            contents = ', '.join('{}: {!r}'.format(key, value)
                                 for key, value in self.items())
        return '{}({{{}}})'.format(self.__class__.__name__, contents)


class Event(metaclass=ABCMeta):
    """Abstraction of a coincident trigger.

    Attributes
    ----------
    singles : list, tuple
        Sequence of `SingleEvent`
    template_args : dict
        Dictionary of template parameters

    """

    # Attributes shown by str() and repr(). The double underscore means this
    # is stored as ``_Event__str_keys``.
    __str_keys = ('singles',)

    @property
    @abstractmethod
    def singles(self):
        """Sequence of `SingleEvent` objects that make up this coincidence."""
        raise NotImplementedError

    @property
    @abstractmethod
    def template_args(self):
        """Dictionary of template parameters for this coincidence."""
        raise NotImplementedError

    def __str__(self):
        shown = self.__str_keys
        return _fmt(self, shown)

    __repr__ = __str__


class SingleEvent(metaclass=ABCMeta):
    """Abstraction of a single-detector trigger.

    Attributes
    ----------
    detector : str
        Instrument name (e.g. 'H1')
    snr : float
        Signal to noise ratio
    phase : float
        Phase on arrival
    time : float
        GPS time on arrival
    zerolag_time : float
        GPS time on arrival in zero-lag data, without time slides applied
    psd : `REAL8FrequencySeries`
        Power spectral density
    snr_series : `COMPLEX8TimeSeries`
        SNR time series

    Notes
    -----
    Every attribute except ``snr_series`` is abstract. ``snr_series``
    defaults to ``None`` for sources that do not provide one.

    """

    # Attributes always shown by str() and repr(). ``zerolag_time`` is added
    # only when it differs from ``time``. Stored as
    # ``_SingleEvent__str_keys`` because of name mangling.
    __str_keys = ('detector', 'snr', 'phase', 'time')

    @property
    @abstractmethod
    def detector(self):
        """Instrument name."""
        raise NotImplementedError

    @property
    @abstractmethod
    def snr(self):
        """Signal to noise ratio."""
        raise NotImplementedError

    @property
    @abstractmethod
    def phase(self):
        """Phase on arrival."""
        raise NotImplementedError

    @property
    @abstractmethod
    def time(self):
        """GPS time on arrival."""
        raise NotImplementedError

    @property
    @abstractmethod
    def zerolag_time(self):
        """GPS time on arrival without time slides applied."""
        raise NotImplementedError

    @property
    @abstractmethod
    def psd(self):
        """Power spectral density."""
        raise NotImplementedError

    @property
    def snr_series(self):
        """SNR time series, or ``None`` if not available."""
        return None

    def __str__(self):
        time_slid = self.time != self.zerolag_time
        extra = ('zerolag_time',) if time_slid else ()
        return _fmt(self, self.__str_keys + extra)

    __repr__ = __str__

ligo/skymap/io/events/detector_disabled.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Modify events by artificially disabling specified detectors."""
from .base import Event, EventSource

__all__ = ('DetectorDisabledEventSource', 'DetectorDisabledError')


class DetectorDisabledError(ValueError):
    """Disabling the requested detectors is not meaningful for an event.

    Raised when none of the disabled detectors took part in the event, or
    when every detector that took part would be disabled.
    """


class DetectorDisabledEventSource(EventSource):
    """Wrap another event source, hiding triggers from some detectors.

    Parameters
    ----------
    base_source : `~ligo.skymap.io.events.EventSource`
        The underlying event source. Its keys and length are passed through.
    disabled_detectors : iterable of str
        Names of detectors whose single-detector triggers are dropped.
    raises : bool, optional
        If True (the default), reading an event's ``singles`` raises
        `DetectorDisabledError` in two cases: when disabling would not
        affect that event, or when it would remove every detector.

    """

    def __init__(self, base_source, disabled_detectors, raises=True):
        self.base_source = base_source
        self.disabled_detectors = set(disabled_detectors)
        self.raises = raises

    def __len__(self):
        return len(self.base_source)

    def __iter__(self):
        return iter(self.base_source)

    def __getitem__(self, key):
        wrapped = self.base_source[key]
        return DetectorDisabledEvent(self, wrapped)


class DetectorDisabledEvent(Event):
    """View of a base event with the source's disabled detectors removed."""

    def __init__(self, source, base_event):
        self.source = source
        self.base_event = base_event

    @property
    def template_args(self):
        """Template parameters of the underlying event, unchanged."""
        return self.base_event.template_args

    @property
    def singles(self):
        """Single-detector triggers from detectors that are not disabled.

        Raises
        ------
        DetectorDisabledError
            If ``source.raises`` is true and either no participating detector
            is disabled, or every participating detector is disabled.

        """
        disabled = self.source.disabled_detectors
        if self.source.raises:
            present = {single.detector for single in self.base_event.singles}
            if not present & disabled:
                outcome = 'have no effect on'
            elif not present - disabled:
                outcome = 'exclude all data for'
            else:
                outcome = None
            if outcome is not None:
                raise DetectorDisabledError(
                    'Disabling detectors {{{}}} would {} this event with '
                    'detectors {{{}}}'.format(
                        ' '.join(sorted(disabled)), outcome,
                        ' '.join(sorted(present))))
        kept = tuple(single for single in self.base_event.singles
                     if single.detector not in disabled)
        return kept


# Module-level alias for this module's event source class.
# NOTE(review): presumably a generic opener looks up ``<module>.open`` by
# name; confirm. This shadows the builtin ``open`` in this module only.
open = DetectorDisabledEventSource

ligo/skymap/io/events/gracedb.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
from ligo.gracedb import rest

from .base import EventSource
from .ligolw import LigoLWEventSource

__all__ = ('GraceDBEventSource',)


class GraceDBEventSource(EventSource):
    """Read events from GraceDB.

    Parameters
    ----------
    graceids : list
        List of GraceDB ID strings.
    client : `ligo.gracedb.rest.GraceDb`, optional
        Client object

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    Notes
    -----
    Events are downloaded on every access. For the requested ID,
    ``coinc.xml`` and ``psd.xml.gz`` are fetched and parsed with
    `~ligo.skymap.io.events.ligolw.LigoLWEventSource`. Nothing is cached.

    """

    def __init__(self, graceids, client=None):
        self._client = rest.GraceDb() if client is None else client
        self._graceids = graceids

    def __len__(self):
        return len(self._graceids)

    def __iter__(self):
        return iter(self._graceids)

    def __getitem__(self, graceid):
        fetch = self._client.files
        coinc_file = fetch(graceid, 'coinc.xml')
        psd_file = fetch(graceid, 'psd.xml.gz')
        events = LigoLWEventSource(
            coinc_file, psd_file=psd_file, coinc_def=None).values()
        # The uploaded coinc.xml must contain exactly one coincidence.
        (event,) = events
        return event


# Module-level alias for this module's event source class.
# NOTE(review): presumably a generic opener looks up ``<module>.open`` by
# name; confirm. This shadows the builtin ``open`` in this module only.
open = GraceDBEventSource

ligo/skymap/io/events/hdf.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read events from PyCBC-style HDF5 output."""
from operator import itemgetter
from itertools import groupby

import h5py
import numpy as np
import lal
from ligo.segments import segment, segmentlist

from .base import Event, EventSource, SingleEvent

__all__ = ('HDFEventSource',)


class _psd_segment(segment):  # noqa: N801
    """A `segment` that also carries the PSD valid over its span.

    The first constructor argument is the PSD. The remaining arguments are
    the segment bounds and are forwarded to `segment`.
    """

    def __new__(cls, psd, *bounds):
        # segment builds its value in __new__, so drop the PSD before
        # delegating; it is attached as an attribute in __init__ instead.
        return segment.__new__(cls, *bounds)

    def __init__(self, psd, *bounds):
        self.psd = psd


def _hdf_file(f):
    """Return an open, read-only `h5py.File` for *f*.

    An existing `h5py.File` is returned unchanged. A file-like object that
    has both ``read`` and ``name`` attributes is reopened by its name.
    Anything else is passed directly to `h5py.File`.
    """
    if isinstance(f, h5py.File):
        return f
    named_fileobj = hasattr(f, 'read') and hasattr(f, 'name')
    return h5py.File(f.name if named_fileobj else f, 'r')


def _classify_hdf_file(f, sample):
    if sample in f:
        return 'coincs'
    for key, value in f.items():
        if isinstance(value, h5py.Group):
            if 'psds' in value:
                return 'psds'
            if 'snr' in value and 'coa_phase' in value and 'end_time' in value:
                return 'triggers'
    if 'parameters' in f.attrs:
        return 'bank'
    raise ValueError('Unrecognized PyCBC file type')


class HDFEventSource(EventSource):
    """Read events from PyCBC-style HDF5 files.

    Parameters
    ----------
    *files : list of str, file-like object, or `h5py.File` objects
        The PyCBC coinc, bank, psds, and triggers files, in any order.
    sample : str, optional
        Keyword argument. Name of the group in the coinc file to read events
        from. It is also used to recognize which input is the coinc file.
        Defaults to ``'foreground'``.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    """

    def __init__(self, *files, **kwargs):
        sample = kwargs.get('sample', 'foreground')

        # Open every input first, then put the open files into buckets by
        # kind ('coincs', 'bank', 'psds', 'triggers'). Files of the same kind
        # keep their input order.
        handles = [_hdf_file(f) for f in files]
        by_kind = {}
        for handle in handles:
            kind = _classify_hdf_file(handle, sample)
            by_kind.setdefault(kind, []).append(handle)

        coinc_files = by_kind.get('coincs', [])
        if len(coinc_files) != 1:
            raise ValueError('You must provide exactly one coinc file.')
        bank_files = by_kind.get('bank', [])
        if len(bank_files) != 1:
            raise ValueError(
                'You must provide exactly one template bank file.')
        psd_files = by_kind.get('psds')
        if not psd_files:
            raise ValueError('You must provide PSD files.')
        trigger_files = by_kind.get('triggers')
        if not trigger_files:
            raise ValueError('You must provide trigger files.')
        (coinc_file,), (bank_file,) = coinc_files, bank_files

        self._bank = bank_file

        # Detector order comes from the coinc file's ``detector_<n>``
        # attributes, sorted by n. All per-detector lists below use it.
        prefix = 'detector_'
        numbered = sorted(
            (int(name[len(prefix):]), ifo)
            for name, ifo in coinc_file.attrs.items()
            if name.startswith(prefix))
        detector_nums, self._ifos = zip(*numbered)

        coinc_group = coinc_file[sample]
        self._timeslide_interval = coinc_file.attrs.get(
            'timeslide_interval', 0)
        self._template_ids = coinc_group['template_id']
        # If there is no timeslide_id dataset, every coinc gets time slide 0.
        self._timeslide_ids = coinc_group.get(
            'timeslide_id', np.zeros(len(self)))
        self._trigger_ids = [
            coinc_group['trigger_id{}'.format(num)] for num in detector_nums]

        # Each trigger file and PSD file holds one top-level group, named
        # after its detector.
        trigger_columns = {}
        for trigger_file in trigger_files:
            (ifo, group), = trigger_file.items()
            trigger_columns[ifo] = [
                group[name] for name in ('snr', 'coa_phase', 'end_time')]
        self._triggers = tuple(trigger_columns[ifo] for ifo in self._ifos)

        psd_segments = {}
        for psd_file in psd_files:
            (ifo, group), = psd_file.items()
            psd_group = group['psds']
            datasets = [psd_group[str(i)] for i in range(len(psd_group))]
            psd_segments[ifo] = segmentlist(
                _psd_segment(*segargs) for segargs in zip(
                    datasets, group['start_time'], group['end_time']))
        self._psds = [psd_segments[ifo] for ifo in self._ifos]

    def __len__(self):
        return len(self._template_ids)

    def __iter__(self):
        return iter(range(len(self)))

    def __getitem__(self, id):
        return HDFEvent(self, id)


class HDFEvent(Event):
    """One coincidence from an `HDFEventSource`, read lazily by index."""

    def __init__(self, source, id):
        self._source = source
        self._id = id

    @property
    def singles(self):
        """One `HDFSingleEvent` per detector, in the source's detector
        order."""
        src = self._source
        per_detector = zip(
            src._ifos, src._trigger_ids, src._triggers, src._psds)
        result = []
        for num, (ifo, trigger_ids, triggers, psds) in enumerate(
                per_detector):
            result.append(HDFSingleEvent(
                ifo, self._id, num, trigger_ids[self._id],
                src._timeslide_interval, triggers,
                src._timeslide_ids, psds))
        return tuple(result)

    @property
    def template_args(self):
        """Bank entry for this coinc's template: each bank dataset indexed
        by the template ID."""
        template_id = self._source._template_ids[self._id]
        return {name: dataset[template_id]
                for name, dataset in self._source._bank.items()}


class HDFSingleEvent(SingleEvent):
    """A single-detector trigger belonging to an `HDFEvent`.

    Values are read lazily from HDF5 datasets. ``_triggers`` is the
    ``[snr, coa_phase, end_time]`` list of datasets for this detector,
    indexed by ``_trigger_id``.
    """

    # Positions within ``_triggers``.
    _SNR, _PHASE, _END_TIME = range(3)

    def __init__(
            self, detector, _coinc_id, _detector_num, _trigger_id,
            _timeslide_interval, _triggers, _timeslide_ids, _psds):
        self._detector = detector
        self._coinc_id = _coinc_id
        self._detector_num = _detector_num
        self._trigger_id = _trigger_id
        self._timeslide_interval = _timeslide_interval
        self._triggers = _triggers
        self._timeslide_ids = _timeslide_ids
        self._psds = _psds

    @property
    def detector(self):
        return self._detector

    @property
    def snr(self):
        return self._triggers[self._SNR][self._trigger_id]

    @property
    def phase(self):
        return self._triggers[self._PHASE][self._trigger_id]

    @property
    def zerolag_time(self):
        return self._triggers[self._END_TIME][self._trigger_id]

    @property
    def time(self):
        """Arrival time with this coinc's time slide applied."""
        zerolag = self.zerolag_time

        # PyCBC does not specify which detector is time-shifted in time slides.
        # Since PyCBC's coincidence format currently supports only two
        # detectors, we arbitrarily apply half of the time slide to each
        # detector.
        shift = self._timeslide_ids[self._coinc_id] * self._timeslide_interval
        if self._detector_num == 0:
            return zerolag - 0.5 * shift
        if self._detector_num == 1:
            return zerolag + 0.5 * shift
        raise AssertionError('This line should not be reached')

    @property
    def psd(self):
        """PSD in effect at this trigger's zero-lag time.

        Returned as a `REAL8FrequencySeries` that starts at frequency bin
        ``int(low_frequency_cutoff / delta_f)``. Values are divided by the
        square of the file's dynamic range factor.
        """
        try:
            dataset = self._psds[self._psds.find(self.zerolag_time)].psd
        except ValueError:
            raise ValueError(
                'No PSD found for detector {} at zero-lag GPS time {}'.format(
                    self.detector, self.zerolag_time))

        file_attrs = dataset.file.attrs
        scale = np.square(file_attrs['dynamic_range_factor'])
        f_low = file_attrs['low_frequency_cutoff']
        delta_f = dataset.attrs['delta_f']
        kmin = int(f_low / delta_f)

        series = lal.CreateREAL8FrequencySeries(
            'psd', 0, kmin * delta_f, delta_f,
            lal.DimensionlessUnit, len(dataset) - kmin)
        series.data.data = dataset[kmin:] / scale
        return series


# Module-level alias for this module's event source class.
# NOTE(review): presumably a generic opener looks up ``<module>.open`` by
# name; confirm. This shadows the builtin ``open`` in this module only.
open = HDFEventSource

ligo/skymap/io/events/ligolw.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read events from pipedown/GstLal-style XML output."""
from collections import OrderedDict, defaultdict
import errno
from functools import lru_cache
import itertools
import logging
import operator
import os.path

from ligo.lw import array, lsctables, param, table
from ligo.lw.ligolw import Element, LIGOLWContentHandler, LIGO_LW
from ligo.lw.lsctables import (
    CoincDefTable, CoincMapTable, CoincTable, ProcessTable, ProcessParamsTable,
    SnglInspiralTable, TimeSlideID, TimeSlideTable)
from ligo.lw.utils import load_filename, load_fileobj
import lal
import lal.series
from lalinspiral.thinca import InspiralCoincDef

from .base import Event, EventSource, SingleEvent
from ...util import ilwd

__all__ = ('LigoLWEventSource',)

log = logging.getLogger('BAYESTAR')


@ilwd.use_in
@array.use_in
@lsctables.use_in
@param.use_in
@table.use_in
class ContentHandler(LIGOLWContentHandler):
    """LIGO-LW XML content handler that `_read_xml` uses to parse documents.

    The stacked ``use_in`` decorators (from ``ligo.lw`` and this package's
    ``util.ilwd``) extend this class with element support. NOTE(review):
    presumably they cover arrays, LSC tables, params, generic tables, and
    legacy ilwd:char IDs respectively; confirm against the ``ligo.lw`` docs.
    """
    pass


def _read_xml(f, fallbackpath=None):
    if f is None:
        doc = filename = None
    elif isinstance(f, Element):
        doc = f
        filename = ''
    elif isinstance(f, str):
        try:
            doc = load_filename(f, contenthandler=ContentHandler)
        except IOError as e:
            if e.errno == errno.ENOENT and fallbackpath and \
                    not os.path.isabs(f):
                f = os.path.join(fallbackpath, f)
                doc = load_filename(f, contenthandler=ContentHandler)
            else:
                raise
        filename = f
    else:
        doc = load_fileobj(f, contenthandler=ContentHandler)
        try:
            filename = f.name
        except AttributeError:
            filename = ''
    return doc, filename


class LigoLWEventSource(OrderedDict, EventSource):
    """Read events from LIGO-LW XML files.

    Parameters
    ----------
    f : str, file-like object, or `ligo.lw.ligolw.Document`
        The name of the file, or the file object, or the XML document object,
        containing the trigger tables.
    psd_file : str, file-like object, or `ligo.lw.ligolw.Document`
        The name of the file, or the file object, or the XML document object,
        containing the PSDs. If not supplied, then PSDs will be read
        automatically from any files mentioned in ``--reference-psd`` command
        line arguments stored in the ProcessParams table.
    coinc_def : `ligo.lw.lsctables.CoincDef`, optional
        An optional coinc definer to limit which events are read.
    fallbackpath : str, optional
        A directory to search for PSD files whose ``--reference-psd`` command
        line arguments have relative paths. Relative paths are always tried
        first relative to the current working directory. If ``f`` is a file
        name, or a file object with a ``name`` attribute, the directory
        containing ``f`` is the fallback and this argument is ignored. This
        argument only takes effect when ``f`` is a document object or an
        unnamed file object.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    Raises
    ------
    KeyError
        If a coincidence was produced by a pipeline whose phase convention is
        unknown.
    ValueError
        If the template parameters differ between the detectors of one
        coincidence, or if the time slide table is absent and a coincidence
        refers to a nonzero time slide.

    """

    def __init__(self, f, psd_file=None, coinc_def=InspiralCoincDef,
                 fallbackpath=None, **kwargs):
        doc, filename = _read_xml(f)
        # Resolve relative --reference-psd paths against the directory of
        # the input file when it is known; otherwise use fallbackpath.
        self._fallbackpath = (
            os.path.dirname(filename) if filename else fallbackpath)
        # Memoize PSD loading for this instance. Every single-detector event
        # calls _psds_for_file, and many events share the same PSD file.
        self._psds_for_file = lru_cache(maxsize=None)(self._psds_for_file)
        # Consume the generator now, filling this OrderedDict with
        # (coinc_event_num, LigoLWEvent) pairs. Extra **kwargs are ignored.
        super().__init__(self._make_events(doc, psd_file, coinc_def))

    # sngl_inspiral columns that make up the template parameters. They must
    # be identical for all detectors in a coincidence.
    _template_keys = '''mass1 mass2
                        spin1x spin1y spin1z spin2x spin2y spin2z
                        f_final'''.split()

    # For each pipeline (by process.program): whether its phases must be
    # negated on read, i.e. whether it uses the anti-FINDCHIRP convention.
    _invert_phases = {
        'pycbc': False,
        'gstlal_inspiral': True,
        'gstlal_inspiral_postcohspiir_online': True,  # FIXME: wild guess
        'bayestar_realize_coincs': True,
        'bayestar-realize-coincs': True,
        'MBTAOnline': True
    }

    @classmethod
    def _phase_convention(cls, program):
        """Return whether phases produced by *program* must be inverted.

        Raises
        ------
        KeyError
            If *program* is not listed in ``_invert_phases``.
        """
        try:
            return cls._invert_phases[program]
        except KeyError:
            raise KeyError(
                ('The pipeline "{}" is unknown, so the phase '
                 'convention could not be deduced.').format(program))

    def _psds_for_file(self, f):
        """Read the PSDs from *f*; the result is indexed by detector name.

        Relative file names are retried under ``self._fallbackpath``.
        """
        doc, _ = _read_xml(f, self._fallbackpath)
        return lal.series.read_psd_xmldoc(doc, root_name=None)

    def _make_events(self, doc, psd_file, coinc_def):
        """Yield ``(coinc_event_num, LigoLWEvent)`` for each coincidence."""
        # Look up necessary tables.
        coinc_table = CoincTable.get_table(doc)
        coinc_map_table = CoincMapTable.get_table(doc)
        sngl_inspiral_table = SnglInspiralTable.get_table(doc)
        try:
            time_slide_table = TimeSlideTable.get_table(doc)
        except ValueError:
            offsets_by_time_slide_id = None
        else:
            offsets_by_time_slide_id = time_slide_table.as_dict()

        # Indices to speed up lookups by ID.
        key = operator.attrgetter('coinc_event_id')
        event_ids_by_coinc_event_id = {
            coinc_event_id:
                tuple(coinc_map.event_id for coinc_map in coinc_maps
                      if coinc_map.table_name == SnglInspiralTable.tableName)
            for coinc_event_id, coinc_maps
            in itertools.groupby(sorted(coinc_map_table, key=key), key=key)}
        sngl_inspirals_by_event_id = {
            row.event_id: row for row in sngl_inspiral_table}

        # Filter rows by coinc_def if requested.
        if coinc_def is not None:
            coinc_def_table = CoincDefTable.get_table(doc)
            coinc_def_ids = {
                row.coinc_def_id for row in coinc_def_table
                if (row.search, row.search_coinc_type) ==
                (coinc_def.search, coinc_def.search_coinc_type)}
            coinc_table = [
                row for row in coinc_table
                if row.coinc_def_id in coinc_def_ids]

        snr_dict = dict(self._snr_series_by_sngl_inspiral(doc))

        process_table = ProcessTable.get_table(doc)
        program_for_process_id = {
            row.process_id: row.program for row in process_table}

        try:
            process_params_table = ProcessParamsTable.get_table(doc)
        except ValueError:
            psd_filenames_by_process_id = {}
        else:
            psd_filenames_by_process_id = {
                process_param.process_id: process_param.value
                for process_param in process_params_table
                if process_param.param == '--reference-psd'}

        ts0 = TimeSlideID(0)
        for time_slide_id in {coinc.time_slide_id for coinc in coinc_table}:
            if offsets_by_time_slide_id is None and time_slide_id == ts0:
                log.warning(
                    'Time slide record is missing for %s, '
                    'guessing that this is zero-lag', time_slide_id)

        # Warn once per distinct pipeline. This also validates up front that
        # every pipeline's phase convention is known.
        for program in {program_for_process_id[coinc.process_id]
                        for coinc in coinc_table}:
            invert_phases = self._phase_convention(program)
            if invert_phases:
                log.warning(
                    'Using anti-FINDCHIRP phase convention; inverting phases. '
                    'This is currently the default and it is appropriate for '
                    'gstlal and MBTA but not pycbc as of observing run 1 '
                    '("O1"). The default setting is likely to change in the '
                    'future.')

        for coinc in coinc_table:
            coinc_event_id = coinc.coinc_event_id
            coinc_event_num = int(coinc_event_id)
            sngls = [sngl_inspirals_by_event_id[event_id] for event_id
                     in event_ids_by_coinc_event_id[coinc_event_id]]
            if offsets_by_time_slide_id is None:
                # Without a time slide table only zero-lag can be assumed.
                # Previously a nonzero ID here failed with a cryptic
                # TypeError from subscripting None.
                if coinc.time_slide_id != ts0:
                    raise ValueError(
                        'The time slide table is missing, so the offsets '
                        'for time slide {} are unknown'.format(
                            coinc.time_slide_id))
                offsets = defaultdict(float)
            else:
                offsets = offsets_by_time_slide_id[coinc.time_slide_id]

            template_args = [
                {key: getattr(sngl, key) for key in self._template_keys}
                for sngl in sngls]
            if any(d != template_args[0] for d in template_args[1:]):
                raise ValueError(
                    'Template arguments are not identical for all detectors!')
            template_args = template_args[0]

            invert_phases = self._phase_convention(
                program_for_process_id[coinc.process_id])

            # time = zero-lag end time plus this detector's slide offset.
            # An explicit psd_file takes priority over --reference-psd.
            singles = tuple(LigoLWSingleEvent(
                self, sngl.ifo, sngl.snr, sngl.coa_phase,
                float(sngl.end + offsets[sngl.ifo]), float(sngl.end),
                psd_file or psd_filenames_by_process_id.get(sngl.process_id),
                snr_dict.get(sngl.event_id), invert_phases)
                for sngl in sngls)

            event = LigoLWEvent(coinc_event_num, singles, template_args)

            yield coinc_event_num, event

    @classmethod
    def _snr_series_by_sngl_inspiral(cls, doc):
        """Yield ``(event_id, COMPLEX8TimeSeries)`` for each embedded SNR
        series.

        Only LIGO_LW elements named after ``COMPLEX8TimeSeries`` that contain
        an ``snr`` array and an ``event_id`` param are used. Anything else is
        skipped.
        """
        for elem in doc.getElementsByTagName(LIGO_LW.tagName):
            try:
                if elem.Name != lal.COMPLEX8TimeSeries.__name__:
                    continue
                array.get_array(elem, 'snr')
                event_id = param.get_pyvalue(elem, 'event_id')
            except (AttributeError, ValueError):
                continue
            else:
                yield event_id, lal.series.parse_COMPLEX8TimeSeries(elem)


class LigoLWEvent(Event):
    """Coincident event read from a LIGO-LW document.

    Parameters
    ----------
    id : int
        Event number (``LigoLWEventSource`` passes ``int(coinc_event_id)``).
    singles : tuple
        The per-detector single events.
    template_args : dict
        Template parameters, identical for all detectors.
    """

    def __init__(self, id, singles, template_args):
        self._id, self._singles, self._template_args = (
            id, singles, template_args)

    @property
    def singles(self):
        """Per-detector single events of this coincidence."""
        return self._singles

    @property
    def template_args(self):
        """Template parameters shared by all detectors."""
        return self._template_args


class LigoLWSingleEvent(SingleEvent):
    """Single-detector trigger of a `LigoLWEvent`.

    Parameters
    ----------
    source : `LigoLWEventSource`
        The event source; PSDs are looked up via its ``_psds_for_file``.
    detector : str
        Detector (IFO) name.
    snr :
        Trigger SNR.
    phase :
        Coalescence phase as stored in the document, or None.
    time : float
        Trigger time including any time-slide offset.
    zerolag_time : float
        Trigger time without the time-slide offset.
    psd_file :
        File from which this detector's PSD is read.
    snr_series : `lal.COMPLEX8TimeSeries` or None
        SNR time series, if the document has one for this trigger.
    invert_phases : bool
        If true, :attr:`phase` is negated and :attr:`snr_series` is complex
        conjugated on access, to convert the phase convention.
    """

    def __init__(self, source, detector, snr, phase, time, zerolag_time,
                 psd_file, snr_series, invert_phases):
        self._source = source
        self._detector = detector
        self._snr = snr
        self._phase = phase
        self._time = time
        self._zerolag_time = zerolag_time
        self._psd_file = psd_file
        self._snr_series = snr_series
        self._invert_phases = invert_phases

    @property
    def detector(self):
        """Detector (IFO) name."""
        return self._detector

    @property
    def snr(self):
        """Trigger SNR."""
        return self._snr

    @property
    def phase(self):
        """Coalescence phase, sign-flipped when phases are inverted."""
        if self._phase is None or not self._invert_phases:
            return self._phase
        return -self._phase

    @property
    def time(self):
        """Trigger time including any time-slide offset."""
        return self._time

    @property
    def zerolag_time(self):
        """Trigger time without the time-slide offset."""
        return self._zerolag_time

    @property
    def psd(self):
        """PSD for this detector, looked up from this trigger's PSD file."""
        psds = self._source._psds_for_file(self._psd_file)
        return psds[self._detector]

    @property
    def snr_series(self):
        """SNR time series; a conjugated copy when phases are inverted."""
        series = self._snr_series
        if not self._invert_phases or series is None:
            return series
        # Work on a full-length copy so the stored series is left untouched.
        flipped = lal.CutCOMPLEX8TimeSeries(
            series, 0, len(series.data.data))
        flipped.data.data = flipped.data.data.conj()
        return flipped


# Module-level ``open`` entry point, mirroring the ``open`` aliases of the
# other event-source modules (``magic.MagicEventSource`` dispatches to
# ``ligolw.open``). NOTE(review): this shadows the builtin ``open`` in this
# module's namespace, so any bare ``open(...)`` call made here at runtime
# resolves to ``LigoLWEventSource``.
open = LigoLWEventSource

ligo/skymap/io/events/magic.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read events from either HDF or LIGO-LW files."""
import os
import sqlite3
from subprocess import check_output

from ligo.lw.ligolw import Element
import h5py

from . import hdf, ligolw, sqlite

__all__ = ('MagicEventSource', 'open')


def _get_file_type(f):
    """Determine the file type by calling the POSIX ``file`` utility.

    Parameters
    ----------
    f : file, str
        A file object or the path to a file

    Returns
    -------
    filetype : bytes
        A string describing the file type

    """
    try:
        f.read
    except AttributeError:
        filetype = check_output(
            ['file', f], env=dict(os.environ, POSIXLY_CORRECT='1'))
    else:
        filetype = check_output(
            ['file', '-'], env=dict(os.environ, POSIXLY_CORRECT='1'), stdin=f)
        f.seek(0)
    _, _, filetype = filetype.partition(b': ')
    return filetype.strip()


def MagicEventSource(f, *args, **kwargs):  # noqa: N802
    """Read events from LIGO-LW XML, LIGO-LW SQLite, or HDF5 files.

    Already-open `h5py.File`, `sqlite3.Connection`, and LIGO-LW ``Element``
    objects are passed straight to :obj:`.hdf.open`, :obj:`.sqlite.open`,
    or :obj:`.ligolw.open`. Anything else is treated as a path or file
    object, and its format is detected with the :manpage:`file(1)` command.
    Gzip-compressed input is assumed to be LIGO-LW XML.

    Parameters
    ----------
    f : str, file-like object, `h5py.File`, `sqlite3.Connection`, or Element
        The event source.
    *args, **kwargs
        Forwarded to the selected opener.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    Raises
    ------
    IOError
        If the detected file type is not one of the supported formats.

    """
    # Objects that are already open are dispatched on their Python type.
    for handle_type, opener in ((h5py.File, hdf.open),
                                (sqlite3.Connection, sqlite.open),
                                (Element, ligolw.open)):
        if isinstance(f, handle_type):
            return opener(f, *args, **kwargs)

    # Everything else: sniff the on-disk format.
    filetype = _get_file_type(f)
    if filetype == b'Hierarchical Data Format (version 5) data':
        opener = hdf.open
    elif filetype.startswith(b'SQLite 3.x database'):
        opener = sqlite.open
    elif filetype.startswith((b'XML', b'gzip')):
        opener = ligolw.open
    else:
        raise IOError('Unknown file format')
    return opener(f, *args, **kwargs)


# Module-level ``open`` entry point (listed in ``__all__``); it is an alias
# for ``MagicEventSource`` and shadows the builtin ``open`` in this module.
open = MagicEventSource

ligo/skymap/io/events/sqlite.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read events from a GstLal-style SQLite output."""
import os
import sqlite3

from ligo.lw import dbtables

from ...util import sqlite
from .ligolw import LigoLWEventSource

__all__ = ('SQLiteEventSource',)


class SQLiteEventSource(LigoLWEventSource):
    """Read events from LIGO-LW SQLite files.

    Parameters
    ----------
    f : str, file-like object, or `sqlite3.Connection` instance
        The SQLite database. A file-like object is used only for its
        ``name``: it is closed, and the database is reopened read-only from
        that path.

    Returns
    -------
    `~ligo.skymap.io.events.EventSource`

    """

    def __init__(self, f, *args, **kwargs):
        if not isinstance(f, sqlite3.Connection):
            if hasattr(f, 'read'):
                # Only the path is needed; release the caller's handle.
                path = f.name
                f.close()
            else:
                path = f
            connection = sqlite.open(path, 'r')
        else:
            connection = f
            path = sqlite.get_filename(f)
        super().__init__(dbtables.get_xml(connection), *args, **kwargs)
        # Directory containing the database, or None when no file name is
        # known. NOTE(review): presumably used by the base class to resolve
        # relative file references; confirm.
        self._fallbackpath = os.path.dirname(path) if path else None


# Module-level ``open`` entry point used by ``magic.MagicEventSource``
# (``sqlite.open``); shadows the builtin ``open`` in this module.
open = SQLiteEventSource

ligo/skymap/plot/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

# Public names re-exported from the submodules; grown by the loop below.
__all__ = ()

# Import all symbols from all submodules of this module.
# For every submodule except the ``tests`` package, the exec() below runs
# these statements in this package's namespace:
#     from . import <module>
#     __all__ += getattr(<module>, "__all__", ())
#     from .<module> import *
# Because ``__all__`` starts out as a tuple, each submodule's ``__all__`` must
# also be a tuple: ``tuple += list`` raises TypeError.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    # Drop the loop variable so it does not linger in the package namespace.
    del module

# Clean up
del os, pkgutil

ligo/skymap/plot/allsky.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
469  
470  
471  
472  
473  
474  
475  
476  
477  
478  
479  
480  
481  
482  
483  
484  
485  
486  
487  
488  
489  
490  
491  
492  
493  
494  
495  
496  
497  
498  
499  
500  
501  
502  
503  
504  
505  
506  
507  
508  
509  
510  
511  
512  
513  
514  
515  
516  
517  
518  
519  
520  
521  
522  
523  
524  
525  
526  
527  
528  
529  
530  
531  
532  
533  
534  
535  
536  
537  
538  
539  
540  
541  
542  
543  
544  
545  
546  
547  
548  
549  
550  
551  
552  
553  
554  
555  
556  
557  
558  
559  
560  
561  
562  
563  
564  
565  
566  
567  
568  
569  
570  
571  
572  
573  
574  
575  
576  
577  
578  
579  
580  
581  
582  
583  
584  
585  
586  
587  
588  
589  
590  
591  
592  
593  
594  
595  
596  
597  
598  
599  
600  
601  
602  
603  
604  
605  
606  
607  
608  
609  
610  
611  
612  
613  
614  
615  
616  
617  
618  
619  
620  
621  
622  
623  
624  
625  
626  
627  
628  
629  
630  
631  
632  
633  
634  
635  
636  
637  
638  
639  
640  
641  
642  
643  
644  
645  
646  
647  
648  
649  
650  
651  
652  
653  
654  
655  
656  
657  
658  
#
# Copyright (C) 2012-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Axes subclasses for astronomical mapmaking.

This module adds several :class:`astropy.visualization.wcsaxes.WCSAxes`
subclasses to the Matplotlib projection registry. The projections have names of
the form :samp:`{astro_or_geo} [{lon_units}] {projection}`.

:samp:`{astro_or_geo}` may be ``astro`` or ``geo``. It controls the
reference frame, either celestial (ICRS) or terrestrial (ITRS).

:samp:`{lon_units}` may be ``hours`` or ``degrees``. It controls the units of
the longitude axis. If omitted, ``astro`` implies ``hours`` and ``geo`` implies
degrees.

:samp:`{projection}` may be any of the following:

* ``aitoff`` for the Aitoff all-sky projection

* ``mollweide`` for the Mollweide all-sky projection

* ``globe`` for an orthographic projection, like the three-dimensional view of
  the Earth from a distant satellite

* ``zoom`` for a gnomonic projection suitable for visualizing small zoomed-in
  patches

Some of the projections support additional optional arguments. The ``globe``
projections support the options ``center`` and ``rotate``. The ``zoom``
projections support the options ``center``, ``radius``, and ``rotate``.

Examples
--------
.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='astro hours mollweide')
    ax.grid()

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='geo aitoff')
    ax.grid()

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='astro zoom',
                  center='5h -32d', radius='5 deg', rotate='20 deg')
    ax.grid()

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='geo globe', center='-50d +23d')
    ax.grid()

Complete Example
----------------
The following example demonstrates most of the features of this module.

.. plot::
   :context: reset
   :include-source:
   :align: center

    from astropy.coordinates import SkyCoord
    from astropy.io import fits
    from astropy import units as u
    import ligo.skymap.plot
    from matplotlib import pyplot as plt

    url = 'https://dcc.ligo.org/public/0146/G1701985/001/bayestar_no_virgo.fits.gz'
    center = SkyCoord.from_name('NGC 4993')

    fig = plt.figure(figsize=(4, 4), dpi=100)

    ax = plt.axes(
        [0.05, 0.05, 0.9, 0.9],
        projection='astro globe',
        center=center)

    ax_inset = plt.axes(
        [0.59, 0.3, 0.4, 0.4],
        projection='astro zoom',
        center=center,
        radius=10*u.deg)

    for key in ['ra', 'dec']:
        ax_inset.coords[key].set_ticklabel_visible(False)
        ax_inset.coords[key].set_ticks_visible(False)
    ax.grid()
    ax.mark_inset_axes(ax_inset)
    ax.connect_inset_axes(ax_inset, 'upper left')
    ax.connect_inset_axes(ax_inset, 'lower left')
    ax_inset.scalebar((0.1, 0.1), 5 * u.deg).label()
    ax_inset.compass(0.9, 0.1, 0.2)

    ax.imshow_hpx(url, cmap='cylon')
    ax_inset.imshow_hpx(url, cmap='cylon')
    ax_inset.plot(
        center.ra.deg, center.dec.deg,
        transform=ax_inset.get_transform('world'),
        marker=ligo.skymap.plot.reticle(),
        markersize=30,
        markeredgewidth=3)

"""  # noqa: E501
from itertools import product

from astropy.coordinates import SkyCoord
from astropy.io.fits import Header
from astropy.time import Time
from astropy.visualization.wcsaxes import WCSAxes
from astropy.visualization.wcsaxes.formatter_locator import (
    AngleFormatterLocator)
from astropy.visualization.wcsaxes.frame import EllipticalFrame
from astropy.wcs import WCS
from astropy import units as u
from matplotlib import rcParams
from matplotlib.offsetbox import AnchoredOffsetbox
from matplotlib.patches import ConnectionPatch, FancyArrowPatch, PathPatch
from matplotlib.projections import projection_registry
import numpy as np
from reproject import reproject_from_healpix
from scipy.optimize import minimize_scalar
from .angle import reference_angle_deg
from . import itrs_frame_monkeypatch

itrs_frame_monkeypatch.install()

__all__ = ['AutoScaledWCSAxes', 'ScaleBar']


class WCSInsetPatch(PathPatch):
    """Subclass of `matplotlib.patches.PathPatch` for marking the outline of
    one `astropy.visualization.wcsaxes.WCSAxes` inside another.

    The outline is drawn unfilled, in the inset frame's color and line
    width. It is recomputed from the inset's frame every time the path is
    requested.
    """

    def __init__(self, ax, *args, **kwargs):
        self._ax = ax
        inset_frame = ax.coords.frame
        super().__init__(
            None, *args, fill=False,
            edgecolor=inset_frame.get_color(),
            linewidth=inset_frame.get_linewidth(),
            **kwargs)

    def get_path(self):
        """Return the inset frame outline in the inset's world coordinates.

        Each segment is subdivided into 50 pieces before transforming, so
        that straight frame edges follow the curvature of the projection.
        """
        inset_frame = self._ax.coords.frame
        outline = inset_frame.patch.get_path().interpolated(50)
        return outline.transformed(inset_frame.transform)


class WCSInsetConnectionPatch(ConnectionPatch):
    """Patch to connect an inset WCS axes inside another WCS axes.

    The line joins a corner of the inset axes to the point of the parent
    axes that has the same world coordinates.
    """

    # Location code -> index into ``Bbox.corners()``, whose points are
    # ordered (x0, y0), (x0, y1), (x1, y0), (x1, y1). Codes 1-4 correspond to
    # AnchoredOffsetbox's 'upper right', 'upper left', 'lower left' and
    # 'lower right'.
    _corners_map = {1: 3, 2: 1, 3: 0, 4: 2}

    def __init__(self, ax, ax_inset, loc, *args, **kwargs):
        code = AnchoredOffsetbox.codes.get(loc)
        if code is None:
            code = int(loc)
        corner_inset = ax_inset.viewLim.corners()[self._corners_map[code]]
        # inset pixel -> world -> parent pixel
        inset_to_parent = (ax_inset.coords.frame.transform +
                           ax.coords.frame.transform.inverted())
        corner_parent = inset_to_parent.transform_point(corner_inset)
        inset_frame = ax_inset.coords.frame
        super().__init__(
            corner_parent, corner_inset, 'data', 'data', ax, ax_inset, *args,
            color=inset_frame.get_color(),
            linewidth=inset_frame.get_linewidth(),
            **kwargs)


class AutoScaledWCSAxes(WCSAxes):
    """Axes base class. The pixel scale is adjusted to the DPI of the image,
    and there are a variety of convenience methods.

    The constructor takes a ``header`` describing a reference pixel grid.
    That grid is resized by a single zoom factor so that it fits the axes'
    bounding box, and ``CRPIX``/``CDELT`` are adjusted so that the resized
    grid covers the same part of the sky.
    """

    # Projection name: the ``name`` attribute that matplotlib's projection
    # registry uses to look up axes classes.
    name = 'astro wcs'

    def __init__(self, *args, header, obstime=None, **kwargs):
        """
        Parameters
        ----------
        header : `astropy.io.fits.Header` or dict
            WCS header of the reference pixel grid. It must contain
            ``NAXIS1``, ``NAXIS2``, ``CRPIX1``, ``CRPIX2``, ``CDELT1`` and
            ``CDELT2``. It is copied, not modified.
        obstime : optional
            If given, converted with `astropy.time.Time` and stored in the
            header as ``DATE-OBS`` (UTC, ISO format).

        Other parameters
        ----------------
        args, kwargs :
            Passed to `astropy.visualization.wcsaxes.WCSAxes`, together
            with ``aspect=1``.

        """
        super().__init__(*args, aspect=1, **kwargs)
        h = Header(header, copy=True)
        naxis1 = h['NAXIS1']
        naxis2 = h['NAXIS2']
        # Largest single zoom factor that fits the reference grid inside the
        # axes bounding box.
        scale = min(self.bbox.width / naxis1, self.bbox.height / naxis2)
        h['NAXIS1'] = int(np.ceil(naxis1 * scale))
        h['NAXIS2'] = int(np.ceil(naxis2 * scale))
        # Effective per-axis zoom after rounding the pixel counts up.
        scale1 = h['NAXIS1'] / naxis1
        scale2 = h['NAXIS2'] / naxis2
        # Map the (1-based) reference pixel linearly, so that the first and
        # last pixels of the new grid line up with those of the old one.
        # NOTE: this divides by zero if the input NAXIS1 or NAXIS2 is 1.
        h['CRPIX1'] = (h['CRPIX1'] - 1) * (h['NAXIS1'] - 1) / (naxis1 - 1) + 1
        h['CRPIX2'] = (h['CRPIX2'] - 1) * (h['NAXIS2'] - 1) / (naxis2 - 1) + 1
        h['CDELT1'] /= scale1
        h['CDELT2'] /= scale2
        if obstime is not None:
            h['DATE-OBS'] = Time(obstime).utc.isot
        self.reset_wcs(WCS(h))
        # Show the whole grid. Pixel centers sit at integer 0-based data
        # coordinates, so the outer pixel edges are at -0.5 and NAXIS - 0.5.
        self.set_xlim(-0.5, h['NAXIS1'] - 0.5)
        self.set_ylim(-0.5, h['NAXIS2'] - 0.5)
        self._header = h

    @property
    def header(self):
        """The rescaled FITS header that defines this axes' WCS."""
        return self._header

    def mark_inset_axes(self, ax, *args, **kwargs):
        """Outline the footprint of another WCSAxes inside this one.

        Parameters
        ----------
        ax : `astropy.visualization.wcsaxes.WCSAxes`
            The other axes.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.PathPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.PathPatch`

        Returns
        -------
        patch : `matplotlib.patches.PathPatch`
            The outline, drawn in this axes' world coordinates.

        """
        return self.add_patch(WCSInsetPatch(
            ax, *args, transform=self.get_transform('world'), **kwargs))

    def connect_inset_axes(self, ax, loc, *args, **kwargs):
        """Connect a corner of another WCSAxes to the matching point inside
        this one.

        Parameters
        ----------
        ax : `astropy.visualization.wcsaxes.WCSAxes`
            The other axes.
        loc : int, str
            Which corner to connect: one of the four corner location codes
            of `matplotlib.offsetbox.AnchoredOffsetbox` (1-4, or the
            equivalent strings such as ``'upper left'``).

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.ConnectionPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.ConnectionPatch`

        Returns
        -------
        patch : `matplotlib.patches.ConnectionPatch`

        """
        return self.add_patch(WCSInsetConnectionPatch(
            self, ax, loc, *args, **kwargs))

    def compass(self, x, y, size):
        """Add a compass to indicate the north and east directions.

        Parameters
        ----------
        x, y : float
            Position of compass vertex in axes coordinates.
        size : float
            Size of compass in axes coordinates.

        Returns
        -------
        annotations : list of `matplotlib.text.Annotation`
            The ``'E'`` and ``'N'`` annotations, in that order.

        """
        xy = x, y
        # Normalize the pixel scale matrix to unit determinant, so that the
        # arrow lengths are set by ``size`` rather than by the pixel scale.
        scale = self.wcs.pixel_scale_matrix
        scale /= np.sqrt(np.abs(np.linalg.det(scale)))
        # NOTE(review): the ``ha``/``va`` values unpacked from zip() below
        # are never used; both labels are drawn with ha='center',
        # va='center'. Confirm whether per-label alignment was intended.
        return [self.annotate(label, xy, xy + size * n,
                              self.transAxes, self.transAxes,
                              ha='center', va='center',
                              arrowprops=dict(arrowstyle='<-',
                                              shrinkA=0.0, shrinkB=0.0))
                for n, label, ha, va in zip(scale, 'EN',
                                            ['right', 'center'],
                                            ['center', 'bottom'])]

    def scalebar(self, *args, **kwargs):
        """Add scale bar.

        Parameters
        ----------
        xy : tuple
            The axes coordinates of the scale bar.
        length : `astropy.units.Quantity`
            The length of the scale bar in angle-compatible units.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.FancyArrowPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.FancyArrowPatch`

        Returns
        -------
        patch : `ScaleBar`
            The scale bar (a `matplotlib.patches.FancyArrowPatch`
            subclass). Call its ``label()`` method to add a text label
            showing the length.

        """
        return self.add_patch(ScaleBar(self, *args, **kwargs))

    def _reproject_hpx(self, data, hdu_in=None, order='bilinear',
                       nested=False, field=0, smooth=None):
        """Reproject a HEALPix data set onto this axes' pixel grid.

        A plain `numpy.ndarray` is taken to be in this axes' frame (the
        header's ``RADESYS``). The result is a masked array in which pixels
        with a zero footprint from `reproject.reproject_from_healpix` are
        masked. If ``smooth`` (an angle) is given, the image is then
        convolved with a Gaussian kernel whose width is ``smooth`` converted
        to pixels, and the convolved image is returned instead.
        """
        if isinstance(data, np.ndarray):
            data = (data, self.header['RADESYS'])

        # It's normal for reproject_from_healpix to produce some Numpy invalid
        # value warnings for points that land outside the projection.
        with np.errstate(invalid='ignore'):
            img, mask = reproject_from_healpix(
                data, self.header, hdu_in=hdu_in, order=order, nested=nested,
                field=field)
        # Mask every pixel whose returned footprint is zero.
        img = np.ma.array(img, mask=~mask.astype(bool))

        if smooth is not None:
            # Infrequently used imports
            from astropy.convolution import convolve_fft, Gaussian2DKernel

            # Convert the smoothing length from an angle to pixels, using
            # the mean absolute CDELT (in degrees) as the pixel size.
            pixsize = np.mean(np.abs(self.wcs.wcs.cdelt)) * u.deg
            smooth = (smooth / pixsize).to(u.dimensionless_unscaled).value
            kernel = Gaussian2DKernel(smooth)
            # Ignore invalid-value (0/0) warnings for pixels that have no
            # valid neighbors.
            with np.errstate(invalid='ignore'):
                img = convolve_fft(img, kernel, fill_value=np.nan)

        return img

    def contour_hpx(self, data, hdu_in=None, order='bilinear', nested=False,
                    field=0, smooth=None, **kwargs):
        """Add contour levels for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `~astropy.io.fits.TableHDU` or `~astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.contour`

        Returns
        -------
        contours : `matplotlib.contour.QuadContourSet`

        """  # noqa: E501
        img = self._reproject_hpx(data, hdu_in=hdu_in, order=order,
                                  nested=nested, field=field, smooth=smooth)
        return self.contour(img, **kwargs)

    def contourf_hpx(self, data, hdu_in=None, order='bilinear', nested=False,
                     field=0, smooth=None, **kwargs):
        """Add filled contour levels for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `~astropy.io.fits.TableHDU` or `~astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.contourf`

        Returns
        -------
        contours : `matplotlib.contour.QuadContourSet`

        """  # noqa: E501
        img = self._reproject_hpx(data, hdu_in=hdu_in, order=order,
                                  nested=nested, field=field, smooth=smooth)
        return self.contourf(img, **kwargs)

    def imshow_hpx(self, data, hdu_in=None, order='bilinear', nested=False,
                   field=0, smooth=None, **kwargs):
        """Add an image for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `~astropy.io.fits.TableHDU` or `~astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.imshow`

        Returns
        -------
        image : `matplotlib.image.AxesImage`

        """  # noqa: E501
        img = self._reproject_hpx(data, hdu_in=hdu_in, order=order,
                                  nested=nested, field=field, smooth=smooth)
        return self.imshow(img, **kwargs)


class ScaleBar(FancyArrowPatch):
    """Horizontal bar whose on-sky length equals a given angle.

    Parameters
    ----------
    ax : `AutoScaledWCSAxes`
        Axes on which the bar is drawn.
    xy : tuple
        Left end of the bar, in axes coordinates.
    length : `astropy.units.Quantity`
        Angular length of the bar (anything `astropy.units.Quantity`
        accepts).

    Other parameters
    ----------------
    args, kwargs :
        Passed to `matplotlib.patches.FancyArrowPatch`. The defaults are
        ``capstyle='round'``, ``color='black'``, and the
        ``lines.linewidth`` rcParam as the line width; kwargs override them.
    """

    def _func(self, dx, x, y):
        # Squared mismatch between the target length and the on-sky length
        # of a horizontal segment of width dx (axes coordinates) that starts
        # at (x, y).
        start, end = self._transAxesToWorld.transform([[x, y], [x + dx, y]])
        separation = SkyCoord(*start, unit=u.deg).separation(
            SkyCoord(*end, unit=u.deg))
        return np.square((separation - self._length).value)

    def __init__(self, ax, xy, length, *args, **kwargs):
        x, y = xy
        self._ax = ax
        self._length = u.Quantity(length)
        # axes coordinates -> pixel (data) coordinates -> world (degrees)
        self._transAxesToWorld = (
            (ax.transAxes - ax.transData) + ax.coords.frame.transform)
        # Search for the width only up to the right edge of the axes.
        best = minimize_scalar(
            self._func, args=xy, bounds=[0, 1 - x], method='bounded')
        width = best.x
        style = {
            'capstyle': 'round',
            'color': 'black',
            'linewidth': rcParams['lines.linewidth'],
            **kwargs,
        }
        super().__init__(
            xy, (x + width, y), *args,
            arrowstyle='-', shrinkA=0.0, shrinkB=0.0,
            transform=ax.transAxes, **style)

    def label(self, **kwargs):
        """Write the bar's length centered just above it.

        Keyword arguments are passed to `matplotlib.axes.Axes.text`.
        """
        (x_start, y), (x_end, _) = self._posA_posB
        text = ' {0.value:g}{0.unit:unicode}'.format(self._length)
        return self._ax.text(
            (x_start + x_end) / 2, y, text,
            ha='center', va='bottom', transform=self._ax.transAxes, **kwargs)


class Astro:
    """Mixin for axes in the celestial (ICRS) frame.

    The coordinate names are padded to four characters so that appending a
    projection suffix such as ``'-SIN'`` gives an eight-character FITS
    ``CTYPE`` value such as ``'RA---SIN'``.
    """

    _radesys = 'ICRS'
    _xcoord = 'RA--'
    _ycoord = 'DEC-'
    # Central longitude (CRVAL1) of the all-sky projections, in degrees.
    _crval1 = 180


class GeoAngleFormatterLocator(AngleFormatterLocator):
    """Angle formatter/locator that passes tick values through
    `reference_angle_deg` before formatting them.

    NOTE(review): presumably this wraps geographic longitudes into a signed
    range instead of 0-360 degrees; see ``.angle.reference_angle_deg``.
    """

    def formatter(self, values, spacing, *args, **kwargs):
        # Extra positional or keyword arguments (for example a ``format``
        # argument, if the installed astropy passes one) are forwarded
        # unchanged. A narrower signature would raise TypeError for them.
        wrapped = reference_angle_deg(values.to(u.deg).value) * u.deg
        return super().formatter(wrapped, spacing, *args, **kwargs)


class Geo:
    """Mixin for axes in the terrestrial (ITRS) frame.

    Compared with the celestial axes, the horizontal axis is inverted, and
    longitude ticks are labeled with `GeoAngleFormatterLocator`.
    """

    _radesys = 'ITRS'
    _crval1 = 0
    _xcoord = 'TLON'
    _ycoord = 'TLAT'

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.invert_xaxis()
        # Swap in the geographic formatter/locator, copying the settings of
        # the one the WCS axes created for the longitude coordinate.
        lon = self.coords[0]
        previous = lon._formatter_locator
        lon._formatter_locator = GeoAngleFormatterLocator(
            values=previous.values,
            number=previous.number,
            spacing=previous.spacing,
            format=previous.format,
            format_unit=previous.format_unit)


class Degrees:
    """Mixin that labels the longitude axis in degrees."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        longitude = self.coords[0]
        longitude.set_format_unit(u.degree)


class Hours:
    """Mixin that labels the longitude axis in hours of angle."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        longitude = self.coords[0]
        longitude.set_format_unit(u.hourangle)


class Globe(AutoScaledWCSAxes):
    """Orthographic ('SIN') projection, like a view of a distant globe.

    Parameters
    ----------
    center : str or `astropy.coordinates.SkyCoord`, optional
        Position at the center of the view (default ``'0d 0d'``).
    rotate : str or `astropy.units.Quantity`, optional
        If given, written to the ``LONPOLE`` header keyword in degrees.
    """

    def __init__(self, *args, center='0d 0d', rotate=None, **kwargs):
        center = SkyCoord(center).icrs
        suffix = '-SIN'
        # 180 pixels at 2/pi deg per pixel span 2 * 180/pi deg, the diameter
        # of the SIN projection's visible hemisphere.
        header = dict(
            NAXIS=2,
            NAXIS1=180,
            NAXIS2=180,
            CRPIX1=90.5,
            CRPIX2=90.5,
            CRVAL1=center.ra.deg,
            CRVAL2=center.dec.deg,
            CDELT1=-2 / np.pi,
            CDELT2=2 / np.pi,
            CTYPE1=self._xcoord + suffix,
            CTYPE2=self._ycoord + suffix,
            RADESYS=self._radesys,
        )
        if rotate is not None:
            header['LONPOLE'] = u.Quantity(rotate).to_value(u.deg)
        super().__init__(
            *args, frame_class=EllipticalFrame, header=header, **kwargs)


class Zoom(AutoScaledWCSAxes):
    """Gnomonic ('TAN') projection for small zoomed-in patches.

    Parameters
    ----------
    center : str or `astropy.coordinates.SkyCoord`, optional
        Position at the center of the field (default ``'0d 0d'``).
    radius : str or `astropy.units.Quantity`, optional
        Angular half-width of the field (default ``'1 deg'``).
    rotate : str or `astropy.units.Quantity`, optional
        If given, written to the ``LONPOLE`` header keyword in degrees.
    """

    def __init__(self, *args, center='0d 0d', radius='1 deg', rotate=None,
                 **kwargs):
        center = SkyCoord(center).icrs
        radius = u.Quantity(radius).to(u.deg).value
        npix = 512
        half = npix // 2
        header = {
            'NAXIS': 2,
            'NAXIS1': npix,
            'NAXIS2': npix,
            'CRPIX1': half + 0.5,
            'CRPIX2': half + 0.5,
            'CRVAL1': center.ra.deg,
            'CRVAL2': center.dec.deg,
            'CDELT1': -radius / half,
            'CDELT2': radius / half,
            'CTYPE1': self._xcoord + '-TAN',
            'CTYPE2': self._ycoord + '-TAN',
            'RADESYS': self._radesys,
        }
        if rotate is not None:
            header['LONPOLE'] = u.Quantity(rotate).to_value(u.deg)
        super().__init__(*args, header=header, **kwargs)


class AllSkyAxes(AutoScaledWCSAxes):
    """Base class for a multi-purpose all-sky projection.

    Subclasses set ``_wcsprj`` to the three-letter FITS projection code.
    The grid is 360 x 180 pixels at 2*sqrt(2)/pi deg per pixel, which spans
    the +/-2*sqrt(2)*180/pi by +/-sqrt(2)*180/pi deg extent of the AIT and
    MOL projection planes.
    """

    def __init__(self, *args, **kwargs):
        suffix = '-' + self._wcsprj
        pixel_scale = 2 * np.sqrt(2) / np.pi
        header = dict(
            NAXIS=2,
            NAXIS1=360,
            NAXIS2=180,
            CRPIX1=180.5,
            CRPIX2=90.5,
            CRVAL1=self._crval1,
            CRVAL2=0.0,
            CDELT1=-pixel_scale,
            CDELT2=pixel_scale,
            CTYPE1=self._xcoord + suffix,
            CTYPE2=self._ycoord + suffix,
            RADESYS=self._radesys,
        )
        super().__init__(
            *args, frame_class=EllipticalFrame, header=header, **kwargs)
        lon, lat = self.coords[0], self.coords[1]
        lon.set_ticks(spacing=45 * u.deg)
        lat.set_ticks(spacing=30 * u.deg)
        for coord in (lon, lat):
            coord.set_ticklabel(exclude_overlapping=True)


class Aitoff(AllSkyAxes):
    """All-sky axes drawn with the FITS WCS ``AIT`` (Aitoff) projection."""

    _wcsprj = 'AIT'


class Mollweide(AllSkyAxes):
    """All-sky axes drawn with the FITS WCS ``MOL`` (Mollweide) projection."""

    _wcsprj = 'MOL'


# Module namespace dict; each generated class is stored here so that it
# becomes an importable attribute of this module.
moddict = globals()

#
# Create subclasses and register all projections:
# '{astro|geo} {hours|degrees} {aitoff|globe|mollweide|zoom}'
#
bases1 = (Astro, Geo)
bases2 = (Hours, Degrees)
bases3 = (Aitoff, Globe, Mollweide, Zoom)
for bases in product(bases1, bases2, bases3):
    # e.g. (Astro, Hours, Mollweide) -> class 'AstroHoursMollweideAxes'
    # registered under the projection name 'astro hours mollweide'.
    class_name = ''.join(cls.__name__ for cls in bases) + 'Axes'
    projection = ' '.join(cls.__name__.lower() for cls in bases)
    # The generated class only sets ``name``; everything else is inherited
    # from its three base classes.
    new_class = type(class_name, bases, {'name': projection})
    projection_registry.register(new_class)
    moddict[class_name] = new_class
    __all__.append(class_name)

#
# Create some synonyms:
# 'astro' will be short for 'astro hours',
# 'geo' will be short for 'geo degrees'
#
# zip() pairs Astro with Hours and Geo with Degrees, which yields exactly the
# two synonym families described above.
for base1, base2 in zip(bases1, bases2):
    for base3 in (Aitoff, Globe, Mollweide, Zoom):
        bases = (base1, base2, base3)
        orig_class_name = ''.join(cls.__name__ for cls in bases) + 'Axes'
        orig_class = moddict[orig_class_name]
        # e.g. 'AstroMollweideAxes' / 'astro mollweide', a subclass of
        # 'AstroHoursMollweideAxes' that only overrides ``name``.
        class_name = ''.join(cls.__name__ for cls in (base1, base3)) + 'Axes'
        projection = ' '.join(cls.__name__.lower() for cls in (base1, base3))
        new_class = type(class_name, (orig_class,), {'name': projection})
        projection_registry.register(new_class)
        moddict[class_name] = new_class
        __all__.append(class_name)

# Remove loop temporaries and ``projection_registry`` from the module
# namespace, then freeze the public-name list.
del class_name, moddict, projection, projection_registry, new_class
__all__ = tuple(__all__)

ligo/skymap/plot/angle.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
#
# Copyright (C) 2012-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Angle utilities."""
import numpy as np

__all__ = ('reference_angle', 'reference_angle_deg',
           'wrapped_angle', 'wrapped_angle_deg')


def reference_angle(a):
    """Map an angle in radians onto the interval (-pi, pi].

    Accepts a scalar or array-like and returns an array of the same shape.
    """
    full_turn = 2 * np.pi
    wrapped = np.mod(a, full_turn)
    # Angles past pi are shifted back by one full turn into (-pi, 0).
    return np.where(wrapped > np.pi, wrapped - full_turn, wrapped)


def reference_angle_deg(a):
    """Map an angle in degrees onto the interval (-180, 180].

    Accepts a scalar or array-like and returns an array of the same shape.
    """
    wrapped = np.mod(a, 360)
    # Angles past 180 degrees are shifted back by one full turn.
    return np.where(wrapped > 180, wrapped - 360, wrapped)


def wrapped_angle(a):
    """Map an angle in radians onto the interval [0, 2*pi).

    Accepts a scalar or array-like. Because of floating-point rounding, a
    tiny negative input can map to exactly 2*pi.
    """
    full_turn = np.pi * 2
    return np.remainder(a, full_turn)


def wrapped_angle_deg(a):
    """Convert an angle in degrees to a wrapped angle between 0 and 360.

    The result lies in [0, 360) degrees (up to floating-point rounding, where
    a tiny negative input can map to exactly 360). Accepts a scalar or
    array-like.
    """
    return np.mod(a, 360)

ligo/skymap/plot/backdrop.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
#
# Copyright (C) 2017-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Backdrops for astronomical plots."""

import warnings

from astropy.io import fits
from astropy.time import Time
from astropy.utils.data import download_file
from astropy.wcs import WCS
from matplotlib.image import imread
import numpy as np
from PIL.Image import DecompressionBombWarning
from reproject import reproject_interp

__all__ = ('bluemarble', 'blackmarble', 'mellinger', 'reproject_interp_rgb')


def big_imread(*args, **kwargs):
    """Call `matplotlib.image.imread` with PIL's decompression-bomb warning
    silenced.

    The all-sky images used in this module are large enough to trigger
    warnings such as::

        DecompressionBombWarning: Image size (91125000 pixels) exceeds limit of
        89478485 pixels, could be decompression bomb DOS attack.

    Only that warning category is suppressed, and only for the duration of
    the call. All arguments are forwarded to ``imread`` unchanged.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DecompressionBombWarning)
        return imread(*args, **kwargs)


def mellinger():
    """Get the Mellinger Milky Way panorama.

    Download (or load from the Astropy cache) and return the Mellinger Milky
    Way panorama. See http://www.milkywaysky.com.

    Returns
    -------
    `astropy.io.fits.ImageHDU`
        A FITS WCS image in ICRS coordinates.

    Examples
    --------
    .. plot::
       :context: reset
       :include-source:
       :align: center

        from astropy.visualization import (ImageNormalize,
                                           AsymmetricPercentileInterval)
        from astropy.wcs import WCS
        from matplotlib import pyplot as plt
        import numpy as np
        from ligo.skymap.plot import mellinger
        from reproject import reproject_interp

        ax = plt.axes(projection='astro hours aitoff')
        backdrop = mellinger()
        backdrop_wcs = WCS(backdrop.header).dropaxis(-1)
        interval = AsymmetricPercentileInterval(45, 98)
        norm = ImageNormalize(backdrop.data, interval)
        backdrop_reprojected = np.asarray([
            reproject_interp((layer, backdrop_wcs), ax.header)[0]
            for layer in norm(backdrop.data)])
        backdrop_reprojected = np.rollaxis(backdrop_reprojected, 0, 3)
        ax.imshow(backdrop_reprojected)

    """
    url = 'http://galaxy.phy.cmich.edu/~axel/mwpan2/mwpan2_RGB_3600.fits'
    hdulist = fits.open(url, cache=True)
    # Single-element unpacking: raises ValueError unless the file holds
    # exactly one HDU.
    hdu, = hdulist
    return hdu


def bluemarble(t, resolution='low'):
    """Get the "Blue Marble" image.

    Download (or load from the Astropy cache) and return the NASA/NOAO/NPP
    "Blue Marble" image showing landforms and oceans, with a celestial WCS
    whose reference right ascension is the Greenwich mean sidereal time at
    ``t``.

    See https://visibleearth.nasa.gov/view.php?id=74117.

    Parameters
    ----------
    t : `astropy.time.Time`
        Time to embed in the WCS header.
    resolution : {'low', 'high'}
        Specify which version to use: the "low" resolution version (5400x2700
        pixels, the default) or the "high" resolution version (21600x10800
        pixels).

    Returns
    -------
    `astropy.io.fits.ImageHDU`
        A FITS WCS image in ICRS coordinates.

    Examples
    --------
    .. plot::
       :context: reset
       :include-source:
       :align: center

        from matplotlib import pyplot as plt
        from ligo.skymap.plot import bluemarble, reproject_interp_rgb

        obstime = '2017-08-17 12:41:04'
        ax = plt.axes(projection='geo degrees aitoff', obstime=obstime)
        ax.imshow(reproject_interp_rgb(bluemarble(obstime), ax.header))

    """
    sizes = {
        'low': '5400x2700',
        'high': '21600x10800'
    }
    url = ('https://eoimages.gsfc.nasa.gov/images/imagerecords/74000/74117/'
           'world.200408.3x{}.png'.format(sizes[resolution]))
    pixels = big_imread(download_file(url, cache=True))
    nrows, ncols, nbands = pixels.shape

    # The reference pixel (image center) is placed at RA = GMST(t); this
    # presumes the image is centered on the Greenwich meridian.
    gmst = Time(t).sidereal_time('mean', 'greenwich').deg
    cards = {
        'NAXIS': 3,
        'NAXIS1': nbands,
        'NAXIS2': ncols,
        'NAXIS3': nrows,
        'CRPIX2': ncols / 2,
        'CRPIX3': nrows / 2,
        'CRVAL2': gmst % 360,
        'CRVAL3': 0,
        'CDELT2': 360 / ncols,
        # Image rows run from north to south.
        'CDELT3': -180 / nrows,
        'CTYPE2': 'RA---CAR',
        'CTYPE3': 'DEC--CAR',
        'RADESYSa': 'ICRS',
    }
    return fits.ImageHDU(pixels[:, :, :], fits.Header(cards.items()))


def blackmarble(t, resolution='low'):
    """Get the "Black Marble" image.

    Get the NASA/NOAO/NPP image showing city lights, at the sidereal time given
    by t. See https://visibleearth.nasa.gov/view.php?id=79765.

    Parameters
    ----------
    t : `astropy.time.Time`
        Time to embed in the WCS header.
    resolution : {'low', 'mid', 'high'}
        Specify which version to use: the "low" resolution version (3600x1800
        pixels, the default), the "mid" resolution version (13500x6750 pixels),
        or the "high" resolution version (54000x27000 pixels).

    Returns
    -------
    `astropy.io.fits.ImageHDU`
        A FITS WCS image in ICRS coordinates.

    Raises
    ------
    KeyError
        If ``resolution`` is not one of the supported variants.

    Examples
    --------
    .. plot::
       :context: reset
       :include-source:
       :align: center

        from matplotlib import pyplot as plt
        from ligo.skymap.plot import blackmarble, reproject_interp_rgb

        obstime = '2017-08-17 12:41:04'
        ax = plt.axes(projection='geo degrees aitoff', obstime=obstime)
        ax.imshow(reproject_interp_rgb(blackmarble(obstime), ax.header))

    """
    # Map each documented resolution to its image size. 'mid' and 'high'
    # were previously swapped, so 'mid' fetched the largest image.
    variants = {
        'low': '3600x1800',
        'mid': '13500x6750',
        'high': '54000x27000'
    }

    url = ('http://eoimages.gsfc.nasa.gov/images/imagerecords/79000/79765/'
           'dnb_land_ocean_ice.2012.{}_geo.tif'.format(variants[resolution]))
    img = big_imread(download_file(url, cache=True))
    height, width, ndim = img.shape
    # Reference RA of the image center is the Greenwich mean sidereal time.
    gmst_deg = Time(t).sidereal_time('mean', 'greenwich').deg
    header = fits.Header(dict(
        NAXIS=3,
        NAXIS1=ndim, NAXIS2=width, NAXIS3=height,
        CRPIX2=width / 2, CRPIX3=height / 2,
        CRVAL2=gmst_deg % 360, CRVAL3=0,
        CDELT2=360 / width,
        CDELT3=-180 / height,
        CTYPE2='RA---CAR',
        CTYPE3='DEC--CAR',
        RADESYSa='ICRS').items())
    return fits.ImageHDU(img[:, :, :], header)


def reproject_interp_rgb(input_data, *args, **kwargs):
    """Reproject a three-channel (RGB) image one channel at a time.

    Parameters
    ----------
    input_data : `astropy.io.fits.ImageHDU`
        Image whose ``data`` is indexed as ``data[row, column, channel]``;
        only channels 0-2 are used. The celestial part of its header's WCS
        describes the input pixel grid.
    *args, **kwargs
        Forwarded to `reproject.reproject_interp` (e.g. the output projection).

    Returns
    -------
    `numpy.ndarray`
        The reprojected channels stacked along the last axis, cast back to
        the input data's dtype.
    """
    pixels = input_data.data
    celestial = WCS(input_data.header).celestial
    channels = [
        reproject_interp((pixels[:, :, band], celestial),
                         *args, **kwargs)[0].astype(pixels.dtype)
        for band in range(3)
    ]
    return np.moveaxis(np.stack(channels), 0, -1)

ligo/skymap/plot/cmap.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
#
# Copyright (C) 2014-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Register some extra Matplotlib color maps"""

try:
    from importlib import resources
except ImportError:
    # FIXME: remove after dropping support for Python < 3.7
    import importlib_resources as resources

from matplotlib import cm
from matplotlib import colors
import numpy as np

__all__ = ()


def _register_cmap(cmap):
    """Register a color map with Matplotlib under its own name.

    ``matplotlib.cm.register_cmap`` was deprecated in Matplotlib 3.6 and
    removed in 3.9 in favor of ``matplotlib.colormaps.register``. Use the
    registry when it provides ``register`` and fall back to the legacy
    function on older Matplotlib versions.
    """
    try:
        from matplotlib import colormaps
        register = colormaps.register
    except (ImportError, AttributeError):
        cm.register_cmap(cmap=cmap)
    else:
        register(cmap)


for name in ['cylon']:
    # Read in color map RGB data (comma-separated rows).
    with resources.open_text(__package__, name + '.csv') as f:
        data = np.loadtxt(f, delimiter=',')

    # Build the color map and its reversed variant ("<name>_r").
    for cmap_name, cmap_data in [(name, data), (name + '_r', data[::-1])]:
        cmap = colors.LinearSegmentedColormap.from_list(cmap_name, cmap_data)
        # Expose as a module attribute (at module scope, globals() is the
        # module namespace; relying on locals() for this is fragile).
        globals()[cmap_name] = cmap
        # Register with Matplotlib.
        _register_cmap(cmap)

ligo/skymap/plot/cylon.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
#
# Copyright (C) 2014-2018  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""RGB data for the "Cylon red" color map.

A print- and screen-friendly color map designed specifically for plotting
LSC/Virgo sky maps. The color map is constructed in CIE Lab space, following
a linear ramp in lightness (the `l` coordinate) and a cubic spline in color
components (the `a` and `b` coordinates).

This particular color map was selected from 20 random realizations of this
construction."""

if __name__ == '__main__':  # pragma: no cover
    from colormath.color_conversions import convert_color
    from colormath.color_objects import LabColor, sRGBColor
    from scipy.interpolate import interp1d
    import numpy as np

    def lab_to_rgb(*args):
        """Convert CIE Lab to sRGB with components clipped to [0, 1]."""
        rgb = convert_color(LabColor(*args), sRGBColor)
        return np.clip(rgb.get_value_tuple(), 0, 1)

    # Spline control points: lightness runs linearly from 100 down to 0.
    L_samples = np.linspace(100, 0, 5)

    a_samples = (
        33.34664938,
        98.09940562,
        84.48361516,
        76.62970841,
        21.43276891)

    b_samples = (
        62.73345997,
        2.09003022,
        37.28252236,
        76.22507582,
        16.24862535)

    # Evaluate cubic splines for the a and b components on a 255-point
    # lightness ramp.
    lightness = np.linspace(100, 0, 255)
    a_star = interp1d(L_samples, a_samples[::-1], 'cubic')(lightness)
    b_star = interp1d(L_samples, b_samples[::-1], 'cubic')(lightness)

    # Emit the module docstring as '#' comment lines, then one
    # comma-separated RGB row per lightness sample.
    for line in __doc__.splitlines():
        print('#', line)
    for lab in zip(lightness, a_star, b_star):
        print(*lab_to_rgb(*lab), sep=',')

ligo/skymap/plot/itrs_frame_monkeypatch.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
#
# Copyright (C) 2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Workaround for https://github.com/astropy/astropy/pull/9609."""
from astropy.coordinates import ITRS, SphericalRepresentation
from astropy.wcs.utils import _wcs_to_celestial_frame_builtin
from astropy.wcs.utils import WCS_FRAME_MAPPINGS


def wcs_to_celestial_frame(*args, **kwargs):
    """Convert a WCS to a celestial frame, forcing a spherical ITRS frame.

    Delegates to Astropy's built-in converter. If the result is an `ITRS`
    frame, it is rebuilt with the same ``obstime`` and a
    `SphericalRepresentation`; any other frame is returned unchanged.
    """
    frame = _wcs_to_celestial_frame_builtin(*args, **kwargs)
    if not isinstance(frame, ITRS):
        return frame
    return ITRS(obstime=frame.obstime,
                representation_type=SphericalRepresentation)


def install():
    """Install the patched converter into Astropy's WCS frame mappings.

    Replaces entry 0 of ``WCS_FRAME_MAPPINGS`` with a one-element list that
    holds `wcs_to_celestial_frame`.
    """
    patched_mappings = [wcs_to_celestial_frame]
    WCS_FRAME_MAPPINGS[0] = patched_mappings

ligo/skymap/plot/marker.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
#
# Copyright (C) 2016-2019  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Specialized markers."""

from matplotlib.path import Path
import numpy as np

__all__ = ('earth', 'reticle')


_circle = Path.unit_circle()
# Two extra stroked segments (MOVETO + LINETO each) draw a '+' across the
# circle: one horizontal, one vertical.
_cross_vertices = [[-1, 0], [1, 0], [0, -1], [0, 1]]
_cross_codes = [Path.MOVETO, Path.LINETO, Path.MOVETO, Path.LINETO]
earth = Path(np.concatenate([_circle.vertices, _cross_vertices]),
             np.concatenate([_circle.codes, _cross_codes]))
# Keep the module namespace limited to the public marker.
del _circle, _cross_vertices, _cross_codes
earth.__doc__ = """
The Earth symbol (circle and cross).

Examples
--------
.. plot::
   :context: reset
   :include-source:
   :align: center

    from matplotlib import pyplot as plt
    from ligo.skymap.plot.marker import earth

    plt.plot(0, 0, markersize=20, markeredgewidth=2,
             markerfacecolor='none', marker=earth)

"""


def reticle(inner=0.5, outer=1.0, angle=0.0, which='lrtb'):
    """Create a reticle or crosshairs marker.

    Parameters
    ----------
    inner : float
        Distance from the origin to the inside of the crosshairs.
    outer : float
        Distance from the origin to the outside of the crosshairs.
    angle : float
        Rotation in degrees; 0 for a '+' orientation and 45 for 'x'.
    which : str
        Which arms to draw, one character per arm: ``'l'`` (left), ``'r'``
        (right), ``'t'`` (top), ``'b'`` (bottom). An empty string gives an
        empty path.

    Returns
    -------
    path : `matplotlib.path.Path`
        The new marker path, suitable for passing to Matplotlib functions
        (e.g., `plt.plot(..., marker=reticle())`)

    Raises
    ------
    KeyError
        If ``which`` contains a character other than ``'l'``, ``'r'``,
        ``'t'``, or ``'b'``.

    Examples
    --------
    .. plot::
       :context: reset
       :include-source:
       :align: center

        from matplotlib import pyplot as plt
        from ligo.skymap.plot.marker import reticle

        markers = [reticle(inner=0),
                   reticle(which='lt'),
                   reticle(which='lt', angle=45)]

        fig, ax = plt.subplots(figsize=(6, 2))
        ax.set_xlim(-0.5, 2.5)
        ax.set_ylim(-0.5, 0.5)
        for x, marker in enumerate(markers):
            ax.plot(x, 0, markersize=20, markeredgewidth=2, marker=marker)

    """
    angle = np.deg2rad(angle)
    c = np.cos(angle)
    s = np.sin(angle)
    # Right-multiplying row vectors by this matrix rotates them
    # counterclockwise by ``angle``.
    rotation = [[c, s], [-s, c]]
    vertdict = {'l': [-1, 0], 'r': [1, 0], 'b': [0, -1], 't': [0, 1]}
    # Reshape so that an empty ``which`` still yields a (0, 2) array that
    # np.dot can handle.
    verts = np.reshape([vertdict[direction] for direction in which], (-1, 2))
    codes = [Path.MOVETO, Path.LINETO] * len(verts)
    verts = np.dot(verts, rotation)
    # Interleave inner and outer endpoints: one MOVETO/LINETO pair per arm.
    verts = np.swapaxes([inner * verts, outer * verts], 0, 1).reshape(-1, 2)
    return Path(verts, codes)

ligo/skymap/plot/poly.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
#
# Copyright (C) 2012-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Plotting tools for drawing polygons."""
import numpy as np
import healpy as hp

from .angle import reference_angle, wrapped_angle

__all__ = ('subdivide_vertices', 'cut_dateline',
           'cut_prime_meridian', 'make_rect_poly')


def subdivide_vertices(vertices, subdivisions):
    """Densify a closed polygon by linear interpolation along its edges.

    Parameters
    ----------
    vertices : `numpy.ndarray`
        Two-dimensional array with one vertex per row. Consecutive rows are
        joined by edges, and the last row is joined back to the first.
    subdivisions : int
        Number of output points generated for each edge.

    Returns
    -------
    `numpy.ndarray`
        Float array of shape ``(len(vertices) * subdivisions,
        vertices.shape[1])``. Output block ``i`` starts at ``vertices[i - 1]``
        and steps toward (but does not include) ``vertices[i]``.
    """
    result = np.empty((subdivisions * len(vertices), vertices.shape[1]))
    weights = np.arange(subdivisions + 1, dtype=float) / subdivisions
    # Weights on the edge's start and end vertex for each step, as columns
    # so that they broadcast across vertices and coordinates.
    start_weight = weights[:0:-1, np.newaxis]
    end_weight = weights[:-1, np.newaxis]
    # Index -1 wraps, so row i of ``previous`` is vertices[i - 1].
    previous = vertices[np.arange(len(vertices)) - 1]
    steps = (start_weight * previous[:, np.newaxis, :] +
             end_weight * vertices[:, np.newaxis, :])
    result[:] = steps.reshape(result.shape)
    return result


def cut_dateline(vertices):
    """Cut a polygon across the dateline, possibly splitting it into multiple
    polygons.

    Vertices consist of (longitude, latitude) pairs where longitude is given
    as a reference angle (between -π and π). The polygon is rotated by π so
    that the dateline falls on the prime meridian, cut with
    `cut_prime_meridian`, and each piece is rotated back.

    This routine is not meant to cover all possible cases; it will only work
    for convex polygons that extend over less than a hemisphere.
    """
    shifted = vertices.copy()
    shifted[:, 0] += np.pi
    pieces = cut_prime_meridian(shifted)
    for piece in pieces:
        piece[:, 0] -= np.pi
    return pieces


def cut_prime_meridian(vertices):
    """Cut a polygon across the prime meridian, possibly splitting it into
    multiple polygons. Vertices consist of (longitude, latitude) pairs where
    longitude is always given in terms of a wrapped angle (between 0 and 2π).

    This routine is not meant to cover all possible cases; it will only work
    for convex polygons that extend over less than a hemisphere.

    Parameters
    ----------
    vertices : `numpy.ndarray`
        Array whose columns 0 and 1 hold longitude and latitude in radians.
        A repeated closing vertex is tolerated and dropped.

    Returns
    -------
    list of `numpy.ndarray`
        The resulting polygon(s) as arrays of (longitude, latitude) pairs
        with longitudes in the range 0 to 2π.

    Raises
    ------
    NotImplementedError
        If the polygon's edges cross the meridian more than twice.
    """
    # Ensure that the list of vertices does not contain a repeated endpoint.
    if (vertices[0] == vertices[-1]).all():
        vertices = vertices[:-1]

    # Ensure that the longitudes are wrapped from 0 to 2π.
    vertices = np.column_stack((wrapped_angle(vertices[:, 0]), vertices[:, 1]))

    # Test if the segment consisting of points i-1 and i crosses the meridian.
    #
    # If the two longitudes are in [0, 2π), then the shortest arc connecting
    # them crosses the meridian if the difference of the angles is greater
    # than π.
    phis = vertices[:, 0]
    # np.vstack rather than np.row_stack, which is deprecated in NumPy 2.
    phi0, phi1 = np.sort(np.vstack((np.roll(phis, 1), phis)), axis=0)
    crosses_meridian = (phi1 - phi0 > np.pi)

    # Count the number of times that the polygon crosses the meridian.
    meridian_crossings = np.sum(crosses_meridian)

    if meridian_crossings == 0:
        # There were zero meridian crossings, so we can use the
        # original vertices as is.
        out_vertices = [vertices]
    elif meridian_crossings == 1:
        # There was one meridian crossing, so the polygon encloses the pole.
        # Any meridian-crossing edge has to be extended
        # into a curve following the nearest polar edge of the map.
        i, = np.flatnonzero(crosses_meridian)
        v0 = vertices[i - 1]
        v1 = vertices[i]

        # Find the latitude at which the meridian crossing occurs by linear
        # interpolation. Each endpoint is weighted by the angular distance of
        # the *other* endpoint from the meridian, so the crossing latitude
        # tends to v0's latitude as v0 approaches the meridian. (The weights
        # were previously swapped.)
        delta_lon = abs(reference_angle(v1[0] - v0[0]))
        dist0 = abs(reference_angle(v0[0]))
        dist1 = abs(reference_angle(v1[0]))
        lat = (dist1 * v0[1] + dist0 * v1[1]) / delta_lon

        # FIXME: Use this simple heuristic to decide which pole to enclose.
        sign_lat = np.sign(np.sum(vertices[:, 1]))

        # Find the closer of the left or the right map boundary for
        # each vertex in the line segment.
        lon_0 = 0. if v0[0] < np.pi else 2 * np.pi
        lon_1 = 0. if v1[0] < np.pi else 2 * np.pi

        # Set the output vertices to the polar cap plus the original
        # vertices.
        out_vertices = [
            np.vstack((
                vertices[:i],
                [[lon_0, lat],
                 [lon_0, sign_lat * np.pi / 2],
                 [lon_1, sign_lat * np.pi / 2],
                 [lon_1, lat]],
                vertices[i:]))]
    elif meridian_crossings == 2:
        # Since the polygon is assumed to be convex, if there is an even number
        # of meridian crossings, we know that the polygon does not enclose
        # either pole. Then we can use ordinary Euclidean polygon intersection
        # algorithms.

        # Imported here because only this branch needs Shapely.
        from shapely import geometry

        out_vertices = []

        # Construct polygon representing map boundaries.
        frame_poly = geometry.Polygon(np.asarray([
            [0., 0.5 * np.pi],
            [0., -0.5 * np.pi],
            [2 * np.pi, -0.5 * np.pi],
            [2 * np.pi, 0.5 * np.pi]]))

        # Intersect with polygon re-wrapped to lie in [-π, π) or [π, 3π).
        for shift in [0, 2 * np.pi]:
            poly = geometry.Polygon(np.column_stack((
                reference_angle(vertices[:, 0]) + shift, vertices[:, 1])))
            intersection = poly.intersection(frame_poly)
            # Test emptiness explicitly rather than relying on geometry
            # truthiness.
            if not intersection.is_empty:
                assert isinstance(intersection, geometry.Polygon)
                assert intersection.is_simple
                # Read ``.coords``: Shapely 2 geometries no longer expose the
                # NumPy array interface.
                out_vertices.append(np.asarray(intersection.exterior.coords))
    else:
        # There were more than two intersections. Not implemented!
        raise NotImplementedError('The polygon intersected the map boundaries '
                                  'more than two times, so it is probably not '
                                  'simple and convex.')

    # Done!
    return out_vertices


def make_rect_poly(width, height, theta, phi, subdivisions=10):
    """Create a spherical rectangle with half-angles ``width`` and ``height``
    (in degrees), rotated from the north pole to (``theta``, ``phi``).

    The rectangle's edges are densified with `subdivide_vertices`, lifted
    onto the unit sphere, rotated with a HEALPix Euler matrix, and converted
    back to angles.

    Returns
    -------
    `numpy.ndarray`
        Array of (longitude, latitude) pairs in radians, with longitude
        wrapped to the range 0 to 2π.
    """
    # Half-angles in degrees -> half-extents in the tangent (x, y) plane.
    half_w = np.sin(np.deg2rad(width))
    half_h = np.sin(np.deg2rad(height))

    corners = np.asarray([[-half_w, -half_h], [half_w, -half_h],
                          [half_w, half_h], [-half_w, half_h]])
    xy = subdivide_vertices(corners, subdivisions)

    # Lift each point onto the unit sphere (z from the normalization).
    z = np.sqrt(1. - np.square(xy).sum(axis=1, keepdims=True))
    xyz = np.hstack((xy, z))

    rotation = hp.rotator.euler_matrix_new(phi, theta, 0, Y=True)
    xyz = np.dot(xyz, rotation)

    colatitudes, longitudes = hp.vec2ang(xyz)
    return np.column_stack(
        (wrapped_angle(longitudes), 0.5 * np.pi - colatitudes))

ligo/skymap/plot/pp.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
#
# Copyright (C) 2012-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Axes subclass for making probability--probability (P--P) plots.

Example
-------
You can create new P--P plot axes by passing the keyword argument
``projection='pp_plot'`` when creating new Matplotlib axes.

.. plot::
   :context: reset
   :include-source:
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    import numpy as np

    n = 100
    p_values_1 = np.random.uniform(size=n) # One experiment
    p_values_2 = np.random.uniform(size=n) # Another experiment
    p_values_3 = np.random.uniform(size=n) # Yet another experiment

    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111, projection='pp_plot')
    ax.add_confidence_band(n, alpha=0.95) # Add 95% confidence band
    ax.add_diagonal() # Add diagonal line
    ax.add_lightning(n, 20) # Add some random realizations of n samples
    ax.add_series(p_values_1, p_values_2, p_values_3) # Add our data

Or, you can call the constructor of `PPPlot` directly.

.. plot::
   :context: reset
   :include-source:
   :align: center

    from ligo.skymap.plot import PPPlot
    from matplotlib import pyplot as plt
    import numpy as np

    n = 100

    rect = [0.1, 0.1, 0.8, 0.8] # Where to place axes in figure
    fig = plt.figure(figsize=(5, 5))
    ax = PPPlot(fig, rect)
    fig.add_axes(ax)
    ax.add_confidence_band(n, alpha=0.95)
    ax.add_lightning(n, 20)
    ax.add_diagonal()

"""
import matplotlib
from matplotlib import axes
from matplotlib.projections import projection_registry
import scipy.stats
import numpy as np

# Public names of this module (the ones exported by a star import).
__all__ = ('PPPlot',)


class PPPlot(axes.Axes):
    """Construct a probability--probability (P--P) plot.

    The axes are square (aspect ratio 1) and both axes are limited to the
    interval [0, 1]. The Matplotlib projection name is ``'pp_plot'``.
    """

    name = 'pp_plot'

    def __init__(self, *args, **kwargs):
        # Call parent constructor
        super().__init__(*args, **kwargs)

        # Square axes, limits from 0 to 1
        self.set_aspect(1.0)
        self.set_xlim(0.0, 1.0)
        self.set_ylim(0.0, 1.0)

    @staticmethod
    def _make_series(p_values):
        """Yield alternating x and y arrays, one pair per entry of `p_values`.

        A 1-D entry is sorted and paired with the cumulative fraction
        ``k / n`` for ``k = 1, ..., n``. A 2-D entry supplies x in row 0 and
        y in row 1. Every series is padded with the points (0, 0) and (1, 1).

        Raises
        ------
        ValueError
            If an entry is neither 1- nor 2-dimensional.
        """
        for ps in p_values:
            if np.ndim(ps) == 1:
                ps = np.sort(np.atleast_1d(ps))
                n = len(ps)
                xs = np.concatenate(([0.], ps, [1.]))
                ys = np.concatenate(([0.], np.arange(1, n + 1) / n, [1.]))
            elif np.ndim(ps) == 2:
                xs = np.concatenate(([0.], ps[0], [1.]))
                ys = np.concatenate(([0.], ps[1], [1.]))
            else:
                raise ValueError('All series must be 1- or 2-dimensional')
            yield xs
            yield ys

    def add_series(self, *p_values, **kwargs):
        """Add a series of P-values to the plot.

        Parameters
        ----------
        p_values : `numpy.ndarray`
            One or more lists of P-values.

            If an entry in the list is one-dimensional, then it is interpreted
            as an unordered list of P-values. The ranked values will be plotted
            on the horizontal axis, and the cumulative fraction will be plotted
            on the vertical axis.

            If an entry in the list is two-dimensional, then the first subarray
            is plotted on the horizontal axis and the second subarray is
            plotted on the vertical axis.

        drawstyle : {'steps', 'lines', 'default'}
            Plotting style. If ``steps``, then plot steps to represent a
            piecewise constant function. If ``lines``, then connect points with
            straight lines. If ``default`` then use steps if there are at
            least 2 pixels per data point in the shortest series, or else
            lines.

        Returns
        -------
        list of `matplotlib.lines.Line2D`
            The lines created by `matplotlib.axes.Axes.plot`.

        Other parameters
        ----------------
        kwargs :
            optional extra arguments to `matplotlib.axes.Axes.plot`

        """
        # Construct sequence of x, y pairs to pass to plot()
        args = list(self._make_series(p_values))

        # Number of data points in the shortest series. The points of a 2-D
        # entry run along its last axis; len() would count its two rows
        # instead, so the 'default' style would always choose steps for 2-D
        # input. With no series at all there is nothing to measure.
        min_n = min((np.shape(ps)[-1] for ps in p_values), default=0)

        # Make copy of kwargs to pass to plot()
        kwargs = dict(kwargs)
        ds = kwargs.pop('drawstyle', 'default')
        if (ds == 'default' and 2 * min_n > self.bbox.width) or ds == 'lines':
            kwargs['drawstyle'] = 'default'
        else:
            kwargs['drawstyle'] = 'steps-post'

        return self.plot(*args, **kwargs)

    def add_worst(self, *p_values):
        """Mark the point at which the deviation is largest.

        For each series, find the point where ``|y - x|`` is largest. Draw
        dashed guide lines from it to both axes and a solid segment from it
        to the diagonal, and label it with the deviation rounded to two
        decimals. Series that never leave the diagonal are skipped.

        Parameters
        ----------
        p_values : `numpy.ndarray`
            Same as in `add_series`.

        """
        series = list(self._make_series(p_values))
        for xs, ys in zip(series[0::2], series[1::2]):
            i = np.argmax(np.abs(ys - xs))
            x = xs[i]
            y = ys[i]
            if y == x:
                continue
            self.plot([x, x, 0], [0, y, y], '--', color='black', linewidth=0.5)
            if y < x:
                self.plot([x, y], [y, y], '-', color='black', linewidth=1)
                self.text(
                    x, y, ' {0:.02f} '.format(np.around(x - y, 2)),
                    ha='left', va='top')
            else:
                self.plot([x, x], [x, y], '-', color='black', linewidth=1)
                self.text(
                    x, y, ' {0:.02f} '.format(np.around(y - x, 2)),
                    ha='right', va='bottom')

    def add_diagonal(self, *args, **kwargs):
        """Add a diagonal line to the plot, running from (0, 0) to (1, 1).

        Other parameters
        ----------------
        kwargs :
            optional extra arguments to `matplotlib.axes.Axes.plot`

        """
        # Make copy of kwargs to pass to plot()
        kwargs = dict(kwargs)
        kwargs.setdefault('color', 'black')
        kwargs.setdefault('linestyle', 'dashed')
        kwargs.setdefault('linewidth', 0.5)

        # Plot diagonal line
        return self.plot([0, 1], [0, 1], *args, **kwargs)

    def add_lightning(self, nsamples, ntrials, **kwargs):
        """Add P-values drawn from a random uniform distribution, as a visual
        representation of the acceptable scatter about the diagonal.

        Parameters
        ----------
        nsamples : int
            Number of P-values in each trial
        ntrials : int
            Number of line series to draw.

        Other parameters
        ----------------
        kwargs :
            optional extra arguments to `matplotlib.axes.Axes.plot`

        """
        # Draw random samples
        args = np.random.uniform(size=(ntrials, nsamples))

        # Make copy of kwargs to pass to plot()
        kwargs = dict(kwargs)
        kwargs.setdefault('color', 'black')
        kwargs.setdefault('alpha', 0.5)
        kwargs.setdefault('linewidth', 0.25)

        # Plot series
        return self.add_series(*args, **kwargs)

    def add_confidence_band(
            self, nsamples, alpha=0.95, annotate=True, **kwargs):
        """Add a target confidence band.

        For each rank ``k = 0, ..., nsamples`` the band spans the central
        `alpha` interval of a Beta(k + 1, nsamples - k + 1) distribution,
        drawn at height ``k / nsamples``.

        Parameters
        ----------
        nsamples : int
            Number of P-values
        alpha : float, default: 0.95
            Confidence level
        annotate : bool, optional, default: True
            If True, then label the confidence band.

        Other parameters
        ----------------
        fontsize :
            Font size of the annotation label (default: ``'x-small'``); it is
            removed from `kwargs` before they are passed on.
        **kwargs :
            optional extra arguments to `matplotlib.axes.Axes.fill_betweenx`

        """
        n = nsamples
        k = np.arange(0, n + 1)
        p = k / n
        ci_lo, ci_hi = scipy.stats.beta.interval(alpha, k + 1, n - k + 1)

        # Make copy of kwargs to pass to fill_betweenx()
        kwargs = dict(kwargs)
        kwargs.setdefault('color', 'lightgray')
        kwargs.setdefault('edgecolor', 'gray')
        kwargs.setdefault('linewidth', 0.5)
        fontsize = kwargs.pop('fontsize', 'x-small')

        if annotate:
            # A bare '%' would start a comment in TeX.
            percent_sign = r'\%' if matplotlib.rcParams['text.usetex'] else '%'
            label = 'target {0:g}{1:s}\nconfidence band'.format(
                100 * alpha, percent_sign)

            self.annotate(
                label,
                xy=(1, 1),
                xytext=(0, 0),
                xycoords='axes fraction',
                textcoords='offset points',
                annotation_clip=False,
                horizontalalignment='right',
                verticalalignment='bottom',
                fontsize=fontsize,
                arrowprops=dict(
                    arrowstyle="->",
                    shrinkA=0, shrinkB=2, linewidth=0.5,
                    connectionstyle="angle,angleA=0,angleB=45,rad=0"))

        return self.fill_betweenx(p, ci_lo, ci_hi, **kwargs)

    @classmethod
    def _as_mpl_axes(cls):
        """Support placement in figure using the `projection` keyword argument.

        See http://matplotlib.org/devel/add_new_projection.html.
        """
        return cls, {}


# Register PPPlot under its ``name`` ('pp_plot') so that Matplotlib accepts
# ``projection='pp_plot'`` when creating axes, as in the module docstring.
projection_registry.register(PPPlot)

ligo/skymap/plot/util.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
#
# Copyright (C) 2012-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Miscellaneous plotting utilities."""
import matplotlib
from matplotlib import text
from matplotlib import ticker
from matplotlib import patheffects

# Public names of this module (the ones exported by a star import).
__all__ = ('colorbar', 'outline_text')


def colorbar(*args):
    """Add a horizontal colorbar with a compact scientific-notation axis.

    The colorbar is created with ``orientation='horizontal'`` and
    ``shrink=0.4``. Its order-of-magnitude offset text is folded into the
    last tick label, which is then left-aligned.

    Parameters
    ----------
    *args
        Positional arguments forwarded unchanged to
        `matplotlib.pyplot.colorbar` (for example, the mappable).

    Returns
    -------
    cb : `matplotlib.colorbar.Colorbar`
        The colorbar returned by `matplotlib.pyplot.colorbar`.
    """
    from matplotlib import pyplot as plt

    usetex = matplotlib.rcParams['text.usetex']
    tick_locator = ticker.AutoLocator()
    tick_formatter = ticker.ScalarFormatter(useMathText=not usetex)
    tick_formatter.set_scientific(True)
    # NOTE(review): set_powerlimits expects exponents, so these are the
    # exponents 0.1 and 100 rather than the magnitudes 0.1 and 100 -- confirm
    # this is the intended threshold.
    tick_formatter.set_powerlimits((1e-1, 100))

    cb = plt.colorbar(*args,
                      orientation='horizontal', shrink=0.4,
                      ticks=tick_locator, format=tick_formatter)

    # Select the axis that carries the tick labels.
    axis = cb.ax.yaxis if cb.orientation == 'vertical' else cb.ax.xaxis

    # Fold the offset text into the last tick label. In TeX mode each piece
    # is wrapped in braces so that two '$' delimiters never end up adjacent.
    labels = [tick.get_text() for tick in axis.get_ticklabels()]
    template = '{{{0}}}{{{1}}}' if usetex else '{0}{1}'
    labels[-1] = template.format(labels[-1], tick_formatter.get_offset())
    axis.set_ticklabels(labels)
    axis.get_ticklabels()[-1].set_horizontalalignment('left')

    # Paint band edges with the face color to hide the thin white seams that
    # some buggy PDF viewers render between colorbar bands. See:
    # https://github.com/matplotlib/matplotlib/pull/1301
    cb.solids.set_edgecolor("face")

    return cb


def outline_text(ax):
    """Give every text artist under *ax* a white outline.

    Each `matplotlib.text.Text` found by ``ax.findobj`` gets a white stroke
    path effect of line width 2, which makes the text stand out from the
    background. All artists share the same list of path effects.
    """
    white_stroke = patheffects.withStroke(linewidth=2, foreground='w')
    shared_effects = [white_stroke]
    for txt in ax.findobj(match=text.Text):
        txt.set_path_effects(shared_effects)

ligo/skymap/postprocess/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

__all__ = ()

# Import all symbols from all submodules of this module.
# For each module found in this package's directory (except ``tests``), the
# exec'd string:
#   1. imports the submodule itself,
#   2. appends the submodule's ``__all__`` (or nothing, if it has none) to
#      this package's ``__all__``, and
#   3. star-imports the submodule's names into this package.
# ``exec`` is used because the star import is driven by a module name known
# only at runtime. The names come from the package's own directory listing,
# not from user input.
for _, module, _ in pkgutil.iter_modules([os.path.dirname(__file__)]):
    if module not in {'tests'}:
        exec('from . import {0};'
             '__all__ += getattr({0}, "__all__", ());'
             'from .{0} import *'.format(module))
    del module

# Clean up
# NOTE(review): the throwaway loop variable ``_`` is not deleted and remains
# in the package namespace if any module was found.
del os, pkgutil

ligo/skymap/postprocess/contour.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import astropy_healpix as ah
from astropy import units as u
import healpy as hp
import numpy as np

# Public names of this module (the ones exported by a star import).
__all__ = ('contour', 'simplify')


def _norm_squared(vertices):
    return np.sum(np.square(vertices), -1)


def _adjacent_triangle_area_squared(vertices):
    """Squared area of the triangle at each vertex of a closed ring.

    Row ``i`` of *vertices* forms a triangle with its successor and its
    predecessor, with indices wrapping around. The area is half the norm of
    the cross product of the two edge vectors, so its square is a quarter of
    the squared norm.
    """
    to_next = np.roll(vertices, -1, axis=0) - vertices
    to_prev = np.roll(vertices, 1, axis=0) - vertices
    return 0.25 * _norm_squared(np.cross(to_next, to_prev))


def _vec2radec(vertices, degrees=False):
    """Convert Cartesian vectors to an array of (RA, Dec) rows.

    Column 0 is the longitude wrapped into [0, 2*pi). Column 1 is the
    latitude, ``pi/2 - theta``. Both are in radians unless *degrees* is True.
    """
    colatitude, longitude = hp.vec2ang(np.asarray(vertices))
    ra = longitude % (2 * np.pi)
    dec = 0.5 * np.pi - colatitude
    radec = np.column_stack((ra, dec))
    return np.rad2deg(radec) if degrees else radec


def simplify(vertices, min_area):
    """Simplify a polygon on the unit sphere.

    This is a naive, slow implementation of Visvalingam's algorithm (see
    http://bost.ocks.org/mike/simplify/), adapted for linear rings on a
    sphere. The vertex whose triangle with its two ring neighbors has the
    smallest area is removed repeatedly, until every remaining triangle is
    larger than `min_area`.

    Parameters
    ----------
    vertices : `np.ndarray`
        An Nx3 array of Cartesian vertex coordinates. Each vertex should be a
        unit vector.

    min_area : float
        The minimum area of triangles formed by adjacent triplets of vertices.

    Returns
    -------
    vertices : `np.ndarray`
        The retained vertices, in their original order. A ring whose
        triangles are all too small may be reduced to fewer than three
        vertices, or to none at all (an empty array).

    """
    area_squared = _adjacent_triangle_area_squared(vertices)
    min_area_squared = np.square(min_area)

    # Stop once no vertices remain: np.argmin raises ValueError on an empty
    # array. That happens when every vertex of a small ring is eliminated.
    while len(area_squared) > 0:
        i_min_area = np.argmin(area_squared)
        if area_squared[i_min_area] > min_area_squared:
            break

        vertices = np.delete(vertices, i_min_area, axis=0)
        area_squared = np.delete(area_squared, i_min_area)
        # Recompute the triangles, but never let a vertex's area fall below
        # the value it had before this elimination.
        new_area_squared = _adjacent_triangle_area_squared(vertices)
        area_squared = np.maximum(area_squared, new_area_squared)

    return vertices


# A private synonym for ``simplify``. Inside ``contour`` below, the keyword
# argument named ``simplify`` shadows the module-level function, so
# ``contour`` calls the function through this alias instead.
_simplify = simplify


def contour(m, levels, nest=False, degrees=False, simplify=True):
    """Calculate contours from a HEALPix dataset.

    Each contour is the boundary of the set of pixels where ``m >= level``.
    The boundary is traced along pixel edges, and its closed rings are
    extracted as cycles of a graph. The rings are optionally simplified,
    then converted to angles.

    Parameters
    ----------
    m : `numpy.ndarray`
        The HEALPix dataset.
    levels : list
        The list of contour values.
    nest : bool, default=False
        Indicates whether the input sky map is in nested rather than
        ring-indexed HEALPix coordinates (default: ring).
    degrees : bool, default=False
        Whether the contours are in degrees instead of radians.
    simplify : bool, default=True
        Whether to simplify the paths with the module-level ``simplify``
        function, using a minimum triangle area of 0.4 times the HEALPix
        pixel area. Simplified rings with fewer than three vertices are
        discarded.

    Returns
    -------
    list
        A list with the same length as `levels`.
        Each item is a list of disjoint polygons, of which each item is a
        list of points, of which each is a list consisting of the right
        ascension and declination. Each polygon is closed: its first point
        is repeated at the end.

    Notes
    -----
    This function requires :mod:`networkx`, which is imported when it is
    called.

    Examples
    --------
    A very simple example sky map...

    >>> nside = 32
    >>> npix = ah.nside_to_npix(nside)
    >>> ra, dec = hp.pix2ang(nside, np.arange(npix), lonlat=True)
    >>> m = dec
    >>> contour(m, [10, 20, 30], degrees=True)
    [[[[..., ...], ...], ...], ...]

    """
    # Infrequently used import
    import networkx as nx

    # Determine HEALPix resolution.
    npix = len(m)
    nside = ah.npix_to_nside(npix)
    min_area = 0.4 * ah.nside_to_pixel_area(nside).to_value(u.sr)

    # Neighbor table, transposed so that row ``ipix`` lists the 8 neighbors
    # of pixel ``ipix`` (presumably shape (npix, 8)). Missing neighbors are
    # -1. Assuming healpy's documented order (SW, W, NW, N, NE, E, SE, S),
    # the even columns are the four edge-sharing neighbors -- confirm
    # against the healpy documentation.
    neighbors = hp.get_all_neighbours(nside, np.arange(npix), nest=nest).T

    # Loop over the requested contours.
    paths = []
    for level in levels:

        # Find credible region.
        indicator = (m >= level)

        # Find all faces that lie on the boundary.
        # This speeds up the doubly nested ``for`` loop below by allowing us to
        # skip the vast majority of faces that are on the interior or the
        # exterior of the contour.
        tovisit = np.flatnonzero(
            np.any(indicator.reshape(-1, 1) !=
                   indicator[neighbors[:, ::2]], axis=1))

        # Construct a graph of the edges of the contour.
        graph = nx.Graph()
        face_pairs = set()
        for ipix1 in tovisit:
            neighborhood = neighbors[ipix1]
            # Rolling by 2 moves original neighbor columns 2, 0, 6 and 4
            # (the even, edge-sharing ones) to index 4 in turn.
            for _ in range(4):
                neighborhood = np.roll(neighborhood, 2)
                ipix2 = neighborhood[4]

                # Skip this pair of faces if we have already examined it.
                new_face_pair = frozenset((ipix1, ipix2))
                if new_face_pair in face_pairs:
                    continue
                face_pairs.add(new_face_pair)

                # Determine if this pair of faces are on a boundary of the
                # credible level.
                if indicator[ipix1] == indicator[ipix2]:
                    continue

                # Add the common edge of this pair of faces.
                # Label each vertex with the set of faces that they share:
                # ipix1 plus neighbors 2..4 for one endpoint and 4..6 for the
                # other. A label may contain -1 for a missing neighbor; those
                # entries are filtered out when coordinates are computed.
                graph.add_edge(
                    frozenset((ipix1, *neighborhood[2:5])),
                    frozenset((ipix1, *neighborhood[4:7])))
        graph = nx.freeze(graph)

        # Find contours by detecting cycles in the graph.
        cycles = nx.cycle_basis(graph)

        # Construct the coordinates of each vertex by summing the unit
        # vectors of the faces that meet there. The sum points in the
        # direction of their average but is not normalized.
        # NOTE(review): ``simplify`` documents unit-vector input, but these
        # vectors have magnitude greater than 1, so the ``min_area``
        # threshold is applied to scaled triangles -- confirm this is intended.
        cycles = [[
            np.sum(hp.pix2vec(nside, [i for i in v if i != -1], nest=nest), 1)
            for v in cycle] for cycle in cycles]

        # Simplify paths if requested.
        if simplify:
            cycles = [_simplify(cycle, min_area) for cycle in cycles]
            cycles = [cycle for cycle in cycles if len(cycle) > 2]

        # Convert to angles.
        cycles = [
            _vec2radec(cycle, degrees=degrees).tolist() for cycle in cycles]

        # Add to output paths, closing each ring by repeating its first point.
        paths.append([cycle + [cycle[0]] for cycle in cycles])

    return paths

ligo/skymap/postprocess/cosmology.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
#
# Copyright (C) 2013-2020  Leo Singer, Rainer Corley
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Cosmology-related utilities.

All functions in this module use the Planck15 cosmological parameters.
"""

import numpy as np
import astropy.cosmology
import astropy.units as u

# Planck 2015 cosmology shared by every function in this module. This is the
# same realization object that
# ``default_cosmology.get_cosmology_from_string('Planck15')`` looks up. It is
# referenced directly because that helper is deprecated in newer Astropy
# releases.
cosmo = astropy.cosmology.Planck15


def dVC_dVL_for_z(z):
    r"""Ratio, :math:`\mathrm{d}V_C / \mathrm{d}V_L`, between the comoving
    volume element and a naively Euclidean volume element in luminosity
    distance space; given as a function of redshift.

    Given the differential comoving volume per unit redshift,
    :math:`\mathrm{d}V_C / \mathrm{d}z`, and the derivative of luminosity
    distance in terms of redshift, :math:`\mathrm{d}D_L / \mathrm{d}z`, this is
    expressed as:

    .. math::

       \frac{\mathrm{d}V_C}{\mathrm{d}V_L} =
       \frac{\mathrm{d}V_C}{\mathrm{d}z}
       \left(
       {D_L}^2 \frac{\mathrm{d}D_L}{\mathrm{d}z}
       \right)^{-1}.

    In terms of distances in units of the Hubble distance :math:`D_H`, the
    value computed here is
    :math:`1 / [(1 + z)^2 (c_k (1 + z) + E(z) D_M / D_H)]`. Here
    :math:`c_k` is 1 for a flat universe, and otherwise
    :math:`\cosh(\sqrt{\Omega_k} D_C / D_H)` or
    :math:`\cos(\sqrt{-\Omega_k} D_C / D_H)`.
    """
    Ok0 = cosmo.Ok0
    dh = cosmo.hubble_distance
    dm_over_dh = (cosmo.comoving_transverse_distance(z) / dh).value
    dc_over_dh = (cosmo.comoving_distance(z) / dh).value
    one_plus_z = z + 1.0

    # Curvature factor: 1 when Ok0 == 0, cosh(...) when Ok0 > 0, and
    # cos(...) otherwise.
    if Ok0 == 0.0:
        curvature_term = 1.0
    elif Ok0 > 0.0:
        curvature_term = np.cosh(np.sqrt(Ok0) * dc_over_dh)
    else:  # Ok0 < 0.0 or Ok0 is nan
        curvature_term = np.cos(np.sqrt(-Ok0) * dc_over_dh)

    denominator = curvature_term * one_plus_z
    denominator = denominator + dm_over_dh * cosmo.efunc(z)
    denominator = denominator * np.square(one_plus_z)
    return 1.0 / denominator


@np.vectorize
def z_for_DL(DL):
    """Redshift as a function of luminosity distance in Mpc.

    The function is wrapped with `numpy.vectorize`, so `DL` may be a scalar
    or an array. Each element is inverted numerically with
    `astropy.cosmology.z_at_value`, using the luminosity distance of the
    module-level ``cosmo``.
    """
    dl_quantity = DL * u.Mpc
    return astropy.cosmology.z_at_value(cosmo.luminosity_distance, dl_quantity)


def dVC_dVL_for_DL(DL):
    """Same as :meth:`dVC_dVL_for_z`, but as a function of luminosity
    distance.

    `DL` is in Mpc. It is first converted to a redshift with `z_for_DL`.
    """
    redshift = z_for_DL(DL)
    return dVC_dVL_for_z(redshift)

ligo/skymap/postprocess/crossmatch.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Catalog cross matching for HEALPix sky maps."""
from collections import namedtuple

import astropy_healpix as ah
from astropy.coordinates import ICRS, SkyCoord, SphericalRepresentation
from astropy import units as u
import healpy as hp
import numpy as np

from .. import distance
from .. import moc

from .cosmology import dVC_dVL_for_DL

# Public names of this module (the ones exported by a star import).
__all__ = ('crossmatch', 'CrossmatchResult')


def flood_fill(nside, ipix, m, nest=False):
    """Stack-based flood fill algorithm in HEALPix coordinates.

    Starting at pixel `ipix`, set to False every truthy pixel of `m` that is
    reachable through chains of truthy neighboring pixels (any of the up to
    8 HEALPix neighbors). `m` is modified in place.

    Based on <http://en.wikipedia.org/w/index.php?title=Flood_fill&oldid=566525693#Alternative_implementations>.
    """  # noqa: E501
    pending = [ipix]
    while pending:
        current = pending.pop()
        if not m[current]:
            # Already filled, or never part of the region.
            continue
        m[current] = False
        adjacent = hp.get_all_neighbours(nside, current, nest=nest)
        # Pixels with fewer than 8 neighbors report the missing ones as -1;
        # drop those before queueing the rest.
        pending.extend(adjacent[adjacent != -1])


def count_modes(m, nest=False):
    """Count the number of modes in a binary HEALPix image by repeatedly
    applying the flood-fill algorithm.

    WARNING: The input array is clobbered in the process.

    Parameters
    ----------
    m : `numpy.ndarray`
        HEALPix image. Nonzero pixels belong to the region whose connected
        components (modes) are counted.
    nest : bool, optional
        Whether `m` is in nested rather than ring ordering.

    Returns
    -------
    int
        The number of connected components of nonzero pixels.
    """
    npix = len(m)
    nside = ah.npix_to_nside(npix)

    # Visit candidate seed pixels once, in increasing index order. Each flood
    # fill zeroes a whole connected component, so a seed that is already
    # zero belongs to a component that was counted earlier. This visits the
    # same seeds as rescanning for the first nonzero pixel after every fill,
    # without the repeated full-map scans.
    nmodes = 0
    for ipix in np.flatnonzero(m):
        if m[ipix]:
            flood_fill(nside, ipix, m, nest=nest)
            nmodes += 1
    return nmodes


def count_modes_moc(uniq, i):
    """Count the modes of the region made of the first ``i + 1`` entries of
    a multi-order (UNIQ-indexed) sky map.

    Parameters
    ----------
    uniq : `numpy.ndarray`
        UNIQ pixel indices of the multi-order sky map.
    i : int
        Index of the last entry (inclusive) that belongs to the region.

    Returns
    -------
    int
        Number of disconnected regions, as computed by `count_modes` on the
        rasterized mask in nested ordering.
    """
    n_entries = len(uniq)
    included = np.ones(i + 1, dtype=bool)
    excluded = np.zeros(n_entries - i - 1, dtype=bool)
    mask = np.concatenate((included, excluded))
    table = np.rec.fromarrays((uniq, mask), names=('UNIQ', 'MASK'))
    flat_mask = moc.rasterize(table)['MASK']
    return count_modes(flat_mask, nest=True)


def cos_angle_distance(theta0, phi0, theta1, phi1):
    """Cosine of angular separation in radians between two points on the
    unit sphere.

    Points are given by polar angle ``theta`` and azimuth ``phi``, in
    radians. The spherical law of cosines is used, and the result is clipped
    to [-1, 1] to absorb rounding error.
    """
    azimuthal_term = np.cos(phi1 - phi0) * np.sin(theta0) * np.sin(theta1)
    polar_term = np.cos(theta0) * np.cos(theta1)
    return np.clip(azimuthal_term + polar_term, -1, 1)


def angle_distance(theta0, phi0, theta1, phi1):
    """Angular separation in radians between two points on the unit sphere.

    Points are given by polar angle ``theta`` and azimuth ``phi``, in
    radians. The result is the arccosine of `cos_angle_distance`.
    """
    cos_separation = cos_angle_distance(theta0, phi0, theta1, phi1)
    return np.arccos(cos_separation)


# Class to hold the return value of the crossmatch function.
CrossmatchResult = namedtuple(
    'CrossmatchResult',
    'searched_area searched_prob offset searched_modes contour_areas '
    'area_probs contour_modes searched_prob_dist contour_dists '
    'searched_vol searched_prob_vol contour_vols probdensity probdensity_vol')
"""Cross match result as returned by
:func:`~ligo.skymap.postprocess.crossmatch.crossmatch`.

Notes
-----
 - All probabilities returned are between 0 and 1.
 - All angles returned are in degrees.
 - All areas returned are in square degrees.
 - All distances are luminosity distances in units of Mpc.
 - All volumes are in units of Mpc³. If :func:`.crossmatch` was run with
   ``cosmology=False``, then all volumes are Euclidean volumes in luminosity
   distance. If :func:`.crossmatch` was run with ``cosmology=True``, then all
   volumes are comoving volumes.

"""
# Suffixes appended to the per-field docstrings below. The `contours` suffix
# names the ``contours`` parameter of crossmatch().
_same_length_as_coordinates = ''' \
Same length as the `coordinates` argument passed to \
:func:`~ligo.skymap.postprocess.crossmatch.crossmatch`.'''
_same_length_as_contours = ''' \
of the probabilities specified by the `contours` argument passed to \
:func:`~ligo.skymap.postprocess.crossmatch.crossmatch`.'''
_same_length_as_areas = ''' \
of the areas specified by the `areas` argument passed to \
:func:`~ligo.skymap.postprocess.crossmatch.crossmatch`.'''
CrossmatchResult.searched_area.__doc__ = '''\
Area within the 2D credible region containing each target \
position.''' + _same_length_as_coordinates
CrossmatchResult.searched_prob.__doc__ = '''\
Probability within the 2D credible region containing each target \
position.''' + _same_length_as_coordinates
CrossmatchResult.offset.__doc__ = '''\
Angles on the sky between the target positions and the maximum a posteriori \
position.''' + _same_length_as_coordinates
CrossmatchResult.searched_modes.__doc__ = '''\
Number of disconnected regions within the 2D credible regions \
containing each target position.''' + _same_length_as_coordinates
CrossmatchResult.contour_areas.__doc__ = '''\
Area within the 2D credible regions''' + _same_length_as_contours
CrossmatchResult.area_probs.__doc__ = '''\
Probability within the 2D credible regions''' + _same_length_as_areas
CrossmatchResult.contour_modes.__doc__ = '''\
Number of disconnected regions within the 2D credible \
regions''' + _same_length_as_contours
CrossmatchResult.searched_prob_dist.__doc__ = '''\
Cumulative distribution function (CDF) of distance, marginalized over sky \
position, at the distance of each of the targets.''' + \
    _same_length_as_coordinates
CrossmatchResult.contour_dists.__doc__ = '''\
Distance credible intervals, marginalized over sky \
position,''' + _same_length_as_contours
CrossmatchResult.searched_vol.__doc__ = '''\
Volume within the 3D credible region containing each target \
position.''' + _same_length_as_coordinates
CrossmatchResult.searched_prob_vol.__doc__ = '''\
Probability within the 3D credible region containing each target \
position.''' + _same_length_as_coordinates
CrossmatchResult.contour_vols.__doc__ = '''\
Volume within the 3D credible regions''' + _same_length_as_contours
CrossmatchResult.probdensity.__doc__ = '''\
2D probability density per steradian at the positions of each of the \
targets.''' + _same_length_as_coordinates
CrossmatchResult.probdensity_vol.__doc__ = '''\
3D probability density per cubic megaparsec at the positions of each of the \
targets.''' + _same_length_as_coordinates


def crossmatch(sky_map, coordinates=None,
               contours=(), areas=(), modes=False, cosmology=False):
    """Cross match a sky map with a catalog of points.

    Given a sky map and the true right ascension and declination (in radians),
    find the smallest area in deg^2 that would have to be searched to find the
    source, the smallest posterior mass, and the angular offset in degrees from
    the true location to the maximum (mode) of the posterior. Optionally, also
    compute the areas of and numbers of modes within the smallest contours
    containing a given total probability.

    Parameters
    ----------
    sky_map : :class:`astropy.table.Table`
        A multiresolution sky map, as returned by
        :func:`ligo.skymap.io.fits.read_sky_map` called with the keyword
        argument ``moc=True``.

    coordinates : :class:`astropy.coordinates.SkyCoord`, optional
        The catalog of target positions to match against.

    contours : :class:`tuple`, optional
        Credible levels between 0 and 1. If this argument is present, then
        calculate the areas and volumes of the 2D and 3D credible regions that
        contain these probabilities. For example, for ``contours=(0.5, 0.9)``,
        then areas and volumes of the 50% and 90% credible regions.

    areas : :class:`tuple`, optional
        Credible areas in square degrees. If this argument is present, then
        calculate the probability contained in the 2D credible levels that have
        these areas. For example, for ``areas=(20, 100)``, then compute the
        probability within the smallest credible levels of 20 deg² and 100
        deg², respectively.

    modes : :class:`bool`, optional
        If True, then enable calculation of the number of distinct modes or
        islands of probability. Note that this option may be computationally
        expensive.

    cosmology : :class:`bool`, optional
        If True, then search space by descending probability density per unit
        comoving volume. If False, then search space by descending probability
        per luminosity distance cubed.

    Returns
    -------
    result : :class:`~ligo.skymap.postprocess.crossmatch.CrossmatchResult`

    Notes
    -----
    This function can also be used for injection finding; see
    :doc:`/tool/ligo_skymap_stats`.

    Examples
    --------
    First, some imports:

    >>> from astroquery.vizier import VizierClass
    >>> from astropy.coordinates import SkyCoord
    >>> from ligo.skymap.io import read_sky_map
    >>> from ligo.skymap.postprocess import crossmatch

    Next, retrieve the GLADE catalog using Astroquery and get the coordinates
    of all its entries:

    >>> vizier = VizierClass(
    ...     row_limit=-1, columns=['GWGC', '_RAJ2000', '_DEJ2000', 'Dist'])
    >>> cat, = vizier.get_catalogs('VII/281/glade2')
    >>> coordinates = SkyCoord(cat['_RAJ2000'], cat['_DEJ2000'], cat['Dist'])

    Load the multiresolution sky map for S190814bv:

    >>> url = 'https://gracedb.ligo.org/api/superevents/S190814bv/files/bayestar.multiorder.fits'
    >>> skymap = read_sky_map(url, moc=True)

    Perform the cross match:

    >>> result = crossmatch(skymap, coordinates)

    Using the cross match results, we can list the galaxies within the 90%
    credible volume:

    >>> print(cat[result.searched_prob_vol < 0.9])
       GWGC          _RAJ2000             _DEJ2000               Dist
                       deg                  deg                  Mpc
    ---------- -------------------- -------------------- --------------------
       NGC0171   9.3396699999999999 -19.9342460000000017    57.56212553960000
           ---  20.2009090000000064 -31.1146050000000010   137.16022925600001
    ESO540-003   8.9144679999999994 -20.1252980000000008    49.07809291930000
           ---  10.6762720000000009 -21.7740819999999999   276.46938505499998
           ---  13.5855169999999994 -23.5523850000000010   138.44550704800000
           ---  20.6362969999999990 -29.9825149999999958   160.23313164900000
           ---  13.1923879999999993 -22.9750179999999986   236.96795954500001
           ---  11.7813630000000007 -24.3706470000000017   244.25031189699999
           ---  19.1711120000000008 -31.4339490000000019   152.13614001400001
           ---  13.6367060000000002 -23.4948789999999974   141.25162979500001
           ...                  ...                  ...                  ...
           ---  11.3517000000000010 -25.8596999999999966   335.73800000000000
           ---  11.2073999999999998 -25.7149000000000001   309.02999999999997
           ---  11.1875000000000000 -25.7503999999999991   295.12099999999998
           ---  10.8608999999999991 -25.6904000000000003   291.07200000000000
           ---  10.6938999999999975 -25.6778300000000002   323.59399999999999
           ---  15.4935000000000009 -26.0304999999999964   304.78899999999999
           ---  15.2794000000000008 -27.0410999999999966   320.62700000000001
           ---  14.8323999999999980 -27.0459999999999994   320.62700000000001
           ---  14.5341000000000005 -26.0949000000000026   307.61000000000001
           ---  23.1280999999999963 -31.1109199999999966   320.62700000000001
    Length = 1479 rows

    """  # noqa: E501
    # Astropy coordinates that are constructed without distance have
    # a distance field that is unity (dimensionless).
    if coordinates is None:
        true_ra = true_dec = true_dist = None
    else:
        # Ensure that coordinates are in proper frame and representation
        coordinates = SkyCoord(coordinates,
                               representation_type=SphericalRepresentation,
                               frame=ICRS)
        true_ra = coordinates.ra.rad
        true_dec = coordinates.dec.rad
        if np.any(coordinates.distance != 1):
            true_dist = coordinates.distance.to_value(u.Mpc)
        else:
            true_dist = None

    contours = np.asarray(contours)

    # Sort the pixels by descending posterior probability.
    sky_map = np.flipud(np.sort(sky_map, order='PROBDENSITY'))

    # Find the pixel that contains the injection.
    # Express every pixel index at the finest order present so that the
    # target's fine-resolution pixel can be located by a sorted search.
    order, ipix = moc.uniq2nest(sky_map['UNIQ'])
    max_order = np.max(order)
    max_nside = ah.level_to_nside(max_order)
    max_ipix = ipix << np.int64(2 * (max_order - order))
    if true_ra is not None:
        true_theta = 0.5 * np.pi - true_dec
        true_phi = true_ra
        true_pix = hp.ang2pix(max_nside, true_theta, true_phi, nest=True)
        i = np.argsort(max_ipix)
        true_idx = i[np.digitize(true_pix, max_ipix[i]) - 1]

    # Find the angular offset between the mode and true locations.
    # After sorting, row 0 is the pixel with the highest probability density.
    mode_theta, mode_phi = hp.pix2ang(
        ah.level_to_nside(order[0]), ipix[0], nest=True)
    if true_ra is None:
        offset = np.nan
    else:
        offset = np.rad2deg(
            angle_distance(true_theta, true_phi, mode_theta, mode_phi))

    # Calculate the cumulative area in deg2 and the cumulative probability.
    dA = moc.uniq2pixarea(sky_map['UNIQ'])
    dP = sky_map['PROBDENSITY'] * dA
    prob = np.cumsum(dP)
    area = np.cumsum(dA) * np.square(180 / np.pi)

    if true_ra is None:
        searched_area = searched_prob = probdensity = np.nan
    else:
        # Find the smallest area that would have to be searched to find
        # the true location.
        searched_area = area[true_idx]

        # Find the smallest posterior mass that would have to be searched to
        # find the true location.
        searched_prob = prob[true_idx]

        # Find the probability density.
        probdensity = sky_map['PROBDENSITY'][true_idx]

    # Find the contours of the given credible levels.
    contour_idxs = np.digitize(contours, prob) - 1

    # For each of the given confidence levels, compute the area of the
    # smallest region containing that probability.
    contour_areas = np.interp(
        contours, prob, area, left=0, right=4*180**2/np.pi).tolist()

    # For each listed area, find the probability contained within the
    # smallest credible region of that area.
    area_probs = np.interp(areas, area, prob, left=0, right=1).tolist()

    if modes:
        if true_ra is None:
            searched_modes = np.nan
        else:
            # Count up the number of modes in each of the given contours.
            searched_modes = count_modes_moc(sky_map['UNIQ'], true_idx)
        contour_modes = [
            count_modes_moc(sky_map['UNIQ'], i) for i in contour_idxs]
    else:
        searched_modes = np.nan
        contour_modes = np.nan

    # Distance stats now...
    if 'DISTMU' in sky_map.dtype.names:
        dP_dA = sky_map['PROBDENSITY']
        mu = sky_map['DISTMU']
        sigma = sky_map['DISTSIGMA']
        norm = sky_map['DISTNORM']

        # Set up distance grid.
        n_r = 1000
        distmean, _ = distance.parameters_to_marginal_moments(dP, mu, sigma)
        max_r = 6 * distmean
        if true_dist is not None and np.size(true_dist) != 0 \
                and np.max(true_dist) > max_r:
            max_r = np.max(true_dist)
        d_r = max_r / n_r

        # Calculate searched_prob_dist and contour_dists.
        r = d_r * np.arange(1, n_r)
        P_r = distance.marginal_cdf(r, dP, mu, sigma, norm)
        if true_dist is None:
            searched_prob_dist = np.nan
        else:
            searched_prob_dist = np.interp(true_dist, r, P_r, left=0, right=1)
        if len(contours) == 0:
            contour_dists = []
        else:
            # Central (equal-tailed) credible interval for each level.
            lo, hi = np.interp(
                np.vstack((
                    0.5 * (1 - contours),
                    0.5 * (1 + contours)
                )), P_r, r, left=0, right=np.inf)
            contour_dists = (hi - lo).tolist()

        # Calculate volume of each voxel, defined as the region within the
        # HEALPix pixel and contained within the two centric spherical shells
        # with radii (r - d_r / 2) and (r + d_r / 2).
        dV = (np.square(r) + np.square(d_r) / 12) * d_r * dA.reshape(-1, 1)

        # Calculate probability within each voxel.
        dP = np.exp(
            -0.5 * np.square(
                (r.reshape(1, -1) - mu.reshape(-1, 1)) / sigma.reshape(-1, 1)
            )
        ) * (dP_dA * norm / (sigma * np.sqrt(2 * np.pi))).reshape(-1, 1) * dV
        dP[np.isnan(dP)] = 0  # Suppress invalid values

        # Calculate probability density per unit volume.

        if cosmology:
            dV *= dVC_dVL_for_DL(r)
        dP_dV = dP / dV
        i = np.flipud(np.argsort(dP_dV.ravel()))

        P_flat = np.cumsum(dP.ravel()[i])
        V_flat = np.cumsum(dV.ravel()[i])

        contour_vols = np.interp(
            contours, P_flat, V_flat, left=0, right=np.inf).tolist()
        # Scatter the cumulative sums back to (pixel, distance) positions.
        P = np.empty_like(P_flat)
        V = np.empty_like(V_flat)
        P[i] = P_flat
        V[i] = V_flat
        P = P.reshape(dP.shape)
        V = V.reshape(dV.shape)
        if true_dist is None:
            searched_vol = searched_prob_vol = probdensity_vol = np.nan
        else:
            i_radec = true_idx
            # Distance-shell index of each target. np.digitize returns 0 for
            # distances below r[0]; subtracting 1 would give -1, which would
            # silently wrap around to the farthest shell. Clip to the valid
            # range so that nearby targets land in the innermost shell.
            i_dist = np.clip(np.digitize(true_dist, r) - 1, 0, len(r) - 1)
            probdensity_vol = dP_dV[i_radec, i_dist]
            searched_prob_vol = P[i_radec, i_dist]
            searched_vol = V[i_radec, i_dist]
    else:
        searched_vol = searched_prob_vol = searched_prob_dist \
            = probdensity_vol = np.nan
        contour_dists = [np.nan] * len(contours)
        contour_vols = [np.nan] * len(contours)

    # Done.
    return CrossmatchResult(
        searched_area, searched_prob, offset, searched_modes, contour_areas,
        area_probs, contour_modes, searched_prob_dist, contour_dists,
        searched_vol, searched_prob_vol, contour_vols, probdensity,
        probdensity_vol)

ligo/skymap/postprocess/ellipse.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#

import astropy_healpix as ah
from astropy import units as u
from astropy.wcs import WCS
import healpy as hp
import numpy as np

from .. import moc
from ..extern.numpy.quantile import quantile

__all__ = ('find_ellipse',)


def find_ellipse(prob, cl=90, projection='ARC', nest=False):
    """For a HEALPix map, find an ellipse that contains a given probability.

    The orientation is defined as the angle of the semimajor axis
    counterclockwise from west on the plane of the sky. If you think of the
    semimajor distance as the width of the ellipse, then the orientation is the
    clockwise rotation relative to the image x-axis. Equivalently, the
    orientation is the position angle of the semi-minor axis.

    These conventions match the definitions used in DS9 region files [1]_ and
    Aladin drawing commands [2]_.

    Parameters
    ----------
    prob : np.ndarray, astropy.table.Table
        The HEALPix probability map, either as a full rank explicit array
        or as a multi-order map.
    cl : float
        The desired credible level, in percent (default: 90).
    projection : str, optional
        The WCS projection (default: 'ARC', or zenithal equidistant).
        For a list of possible values, see the Astropy documentation [3]_.
    nest : bool
        HEALPix pixel ordering (default: False, or ring ordering). Ignored for
        multi-order maps, which are always nested.

    Returns
    -------
    ra : float
        The ellipse center right ascension in degrees.
    dec : float
        The ellipse center declination in degrees.
    a : float
        The length of the semimajor axis in degrees.
    b : float
        The length of the semiminor axis in degrees.
    pa : float
        The orientation of the ellipse axis on the plane of the sky in degrees.
    area : float
        The area of the ellipse in square degrees.

    If the requested credible level is not contained within the projection,
    then all six values are NaN.

    Notes
    -----
    The center of the ellipse is the median a posteriori sky position. The
    length and orientation of the semi-major and semi-minor axes are measured
    as follows:

    1. The sky map is transformed to a WCS projection that may be specified by
       the caller. The default projection is ``ARC`` (zenithal equidistant), in
       which radial distances are proportional to the physical angular
       separation from the center point.
    2. A 1-sigma ellipse is estimated by calculating the covariance matrix in
       the projected image plane using three rounds of sigma clipping to reject
       distant outlier points.
    3. The 1-sigma ellipse is inflated until it encloses an integrated
       probability of ``cl`` (default: 90%).

    The function returns a tuple of the right ascension, declination,
    semi-major distance, semi-minor distance, and orientation angle, all in
    degrees, followed by the area in square degrees.

    References
    ----------
    .. [1] http://ds9.si.edu/doc/ref/region.html
    .. [2] http://aladin.u-strasbg.fr/java/AladinScriptManual.gml#draw
    .. [3] http://docs.astropy.org/en/stable/wcs/index.html#supported-projections

    Examples
    --------
    **Example 1**

    First, we need some imports.

    >>> from astropy.io import fits
    >>> from astropy.utils.data import download_file
    >>> from astropy.wcs import WCS
    >>> import healpy as hp
    >>> from reproject import reproject_from_healpix
    >>> import subprocess

    Next, we download the BAYESTAR sky map for GW170817 from the
    LIGO Document Control Center.

    >>> url = 'https://dcc.ligo.org/public/0146/G1701985/001/bayestar.fits.gz'  # doctest: +SKIP
    >>> filename = download_file(url, cache=True, show_progress=False)  # doctest: +SKIP
    >>> _, healpix_hdu = fits.open(filename)  # doctest: +SKIP
    >>> prob = hp.read_map(healpix_hdu, verbose=False)  # doctest: +SKIP

    Then, we calculate ellipse and write it to a DS9 region file.

    >>> ra, dec, a, b, pa, area = find_ellipse(prob)  # doctest: +SKIP
    >>> print(*np.around([ra, dec, a, b, pa, area], 5))  # doctest: +SKIP
    195.03732 -19.29358 8.66545 1.1793 63.61698 32.07665
    >>> s = 'fk5;ellipse({},{},{},{},{})'.format(ra, dec, a, b, pa)  # doctest: +SKIP
    >>> open('ds9.reg', 'w').write(s)  # doctest: +SKIP

    Then, we reproject a small patch of the HEALPix map, and save it to a file.

    >>> wcs = WCS()  # doctest: +SKIP
    >>> wcs.wcs.ctype = ['RA---ARC', 'DEC--ARC']  # doctest: +SKIP
    >>> wcs.wcs.crval = [ra, dec]  # doctest: +SKIP
    >>> wcs.wcs.crpix = [128, 128]  # doctest: +SKIP
    >>> wcs.wcs.cdelt = [-0.1, 0.1]  # doctest: +SKIP
    >>> img, _ = reproject_from_healpix(healpix_hdu, wcs, [256, 256])  # doctest: +SKIP
    >>> img_hdu = fits.ImageHDU(img, wcs.to_header())  # doctest: +SKIP
    >>> img_hdu.writeto('skymap.fits')  # doctest: +SKIP

    Now open the image and region file in DS9. You should find that the ellipse
    encloses the probability hot spot. You can load the sky map and region file
    from the command line:

    .. code-block:: sh

        $ ds9 skymap.fits -region ds9.reg

    Or you can do this manually:

        1. Open DS9.
        2. Open the sky map: select "File->Open..." and choose ``skymap.fits``
           from the dialog box.
        3. Open the region file: select "Regions->Load Regions..." and choose
           ``ds9.reg`` from the dialog box.

    Now open the image and region file in Aladin.

        1. Open Aladin.
        2. Open the sky map: select "File->Load Local File..." and choose
           ``skymap.fits`` from the dialog box.
        3. Open the region file: select "File->Load Local File..." and choose
           ``ds9.reg`` from the dialog box.

    You can also compare the original HEALPix file with the ellipse in Aladin:

        1. Open Aladin.
        2. Open the HEALPix file by pasting the URL from the top of this
           example in the Command field at the top of the window and hitting
           return, or by selecting "File->Load Direct URL...", pasting the URL,
           and clicking "Submit."
        3. Open the region file: select "File->Load Local File..." and choose
           ``ds9.reg`` from the dialog box.

    **Example 2**

    This example shows that we get approximately the same answer for GW170817
    if we read it in as a multi-order map.

    >>> from ..io import read_sky_map  # doctest: +SKIP
    >>> skymap_moc = read_sky_map(healpix_hdu, moc=True)  # doctest: +SKIP
    >>> ellipse = find_ellipse(skymap_moc)  # doctest: +SKIP
    >>> print(*np.around(ellipse, 5))  # doctest: +SKIP
    195.03709 -19.27589 8.67611 1.18167 63.60454 32.08015

    **Example 3**

    I'm not showing the `ra` or `pa` output from the examples below because
    the right ascension is arbitrary when dec=90° and the position angle is
    arbitrary when a=b; their arbitrary values may vary depending on your math
    library. Also, I add 0.0 to the outputs because on some platforms you tend
    to get values of dec or pa that get rounded to -0.0, which is within
    numerical precision but would break the doctests (see
    https://stackoverflow.com/questions/11010683).

    This is an example sky map that is uniform in sin(theta) out to a given
    radius in degrees. The 90% credible radius should be 0.9 * radius. (There
    will be deviations for small radius due to finite resolution.)

    >>> def make_uniform_in_sin_theta(radius, nside=512):
    ...     npix = ah.nside_to_npix(nside)
    ...     theta, phi = hp.pix2ang(nside, np.arange(npix))
    ...     theta_max = np.deg2rad(radius)
    ...     prob = np.where(theta <= theta_max, 1 / np.sin(theta), 0)
    ...     return prob / prob.sum()
    ...

    >>> prob = make_uniform_in_sin_theta(1)
    >>> ra, dec, a, b, pa, area = find_ellipse(prob)
    >>> dec, a, b, area  # doctest: +FLOAT_CMP
    (89.90862520480792, 0.8703361458208101, 0.8703357768874356, 2.3788811576269793)

    >>> prob = make_uniform_in_sin_theta(10)
    >>> ra, dec, a, b, pa, area = find_ellipse(prob)
    >>> dec, a, b, area  # doctest: +FLOAT_CMP
    (89.90827657529562, 9.024846562072119, 9.024842703023802, 255.11972196535515)

    >>> prob = make_uniform_in_sin_theta(120)
    >>> ra, dec, a, b, pa, area = find_ellipse(prob)
    >>> dec, a, b, area  # doctest: +FLOAT_CMP
    (90.0, 107.9745037610576, 107.97450376105758, 26988.70467497216)

    **Example 4**

    These are approximately Gaussian distributions.

    >>> from scipy import stats
    >>> def make_gaussian(mean, cov, nside=512):
    ...     npix = ah.nside_to_npix(nside)
    ...     xyz = np.transpose(hp.pix2vec(nside, np.arange(npix)))
    ...     dist = stats.multivariate_normal(mean, cov)
    ...     prob = dist.pdf(xyz)
    ...     return prob / prob.sum()
    ...

    This one is centered at RA=45°, Dec=0° and has a standard deviation of ~1°.

    >>> prob = make_gaussian(
    ...     [1/np.sqrt(2), 1/np.sqrt(2), 0],
    ...     np.square(np.deg2rad(1)))
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (45.0, 0.0, 2.1424077148886744, 2.1420790721225518, 90.0, 14.467701995920123)

    This one is centered at RA=45°, Dec=0°, and is elongated in the north-south
    direction.

    >>> prob = make_gaussian(
    ...     [1/np.sqrt(2), 1/np.sqrt(2), 0],
    ...     np.diag(np.square(np.deg2rad([1, 1, 10]))))
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (44.99999999999999, 0.0, 13.58768882719899, 2.0829846178241853, 90.0, 88.57796576937031)

    This one is centered at RA=0°, Dec=0°, and is elongated in the east-west
    direction.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     np.diag(np.square(np.deg2rad([1, 10, 1]))))
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 13.583918022027149, 2.0823769912401433, 0.0, 88.54622940628761)

    This one is centered at RA=0°, Dec=0°, and has its long axis tilted about
    10° to the west of north.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     [[0.1, 0, 0],
    ...      [0, 0.1, -0.15],
    ...      [0, -0.15, 1]])
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 64.7713312709293, 33.50754131182681, 80.78231196786838, 6372.344658663038)

    This one is centered at RA=0°, Dec=0°, and has its long axis tilted about
    10° to the east of north.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     [[0.1, 0, 0],
    ...      [0, 0.1, 0.15],
    ...      [0, 0.15, 1]])
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 64.77133127093047, 33.50754131182745, 99.21768803213159, 6372.344658663096)

    This one is centered at RA=0°, Dec=0°, and has its long axis tilted about
    80° to the east of north.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     [[0.1, 0, 0],
    ...      [0, 1, 0.15],
    ...      [0, 0.15, 0.1]])
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 64.7756448603915, 33.509863018519894, 170.78252287327365, 6372.425731592412)

    This one is centered at RA=0°, Dec=0°, and has its long axis tilted about
    80° to the west of north.

    >>> prob = make_gaussian(
    ...     [1, 0, 0],
    ...     [[0.1, 0, 0],
    ...      [0, 1, -0.15],
    ...      [0, -0.15, 0.1]])
    ...
    >>> find_ellipse(prob)  # doctest: +FLOAT_CMP
    (0.0, 0.0, 64.77564486039148, 33.50986301851987, 9.217477126726322, 6372.42573159241)

    """  # noqa: E501
    try:
        prob['UNIQ']
    except (IndexError, KeyError, ValueError):
        # No UNIQ column: treat ``prob`` as a flat HEALPix array of per-pixel
        # probabilities, all pixels having the same area (in deg²).
        npix = len(prob)
        nside = ah.npix_to_nside(npix)
        ipix = range(npix)
        area = ah.nside_to_pixel_area(nside).to_value(u.deg**2)
    else:
        # Multi-order map: convert probability density (per steradian) to
        # probability per pixel, then express each pixel's area in deg².
        order, ipix = moc.uniq2nest(prob['UNIQ'])
        nside = 1 << order.astype(int)
        ipix = ipix.astype(int)
        area = ah.nside_to_pixel_area(nside).to_value(u.sr)
        prob = prob['PROBDENSITY'] * area
        area *= np.square(180 / np.pi)
        nest = True

    # Find median a posteriori sky position.
    xyz0 = [quantile(x, 0.5, weights=prob)
            for x in hp.pix2vec(nside, ipix, nest=nest)]
    (ra,), (dec,) = hp.vec2ang(np.asarray(xyz0), lonlat=True)

    # Construct WCS with the specified projection
    # and centered on the median direction.
    wcs = WCS()
    wcs.wcs.crval = [ra, dec]
    wcs.wcs.ctype = ['RA---' + projection, 'DEC--' + projection]

    # Transform HEALPix pixel centers to image-plane coordinates of the
    # chosen projection.
    xy = wcs.wcs_world2pix(
        np.transpose(
            hp.pix2ang(
                nside, ipix, nest=nest, lonlat=True)), 1)

    # Keep only values that were inside the projection.
    keep = np.logical_and.reduce(np.isfinite(xy), axis=1)
    xy = xy[keep]
    prob = prob[keep]
    if not np.isscalar(area):
        area = area[keep]

    # Find covariance matrix, performing three rounds of sigma-clipping
    # to reject outliers.
    keep = np.ones(len(xy), dtype=bool)
    for _ in range(3):
        c = np.cov(xy[keep], aweights=prob[keep], rowvar=False)
        # Mahalanobis distance of every point from the origin (the ellipse
        # center, which is the WCS reference point).
        nsigmas = np.sqrt(np.sum(xy.T * np.linalg.solve(c, xy.T), axis=0))
        keep &= (nsigmas < 3)

    # Find the number of sigma that enclose the cl% credible level.
    i = np.argsort(nsigmas)
    nsigmas = nsigmas[i]
    cls = np.cumsum(prob[i])

    # If the credible level is not within the projection, then stop here and
    # return all nans. Return the same number of values as the success path
    # so that callers can always unpack six values.
    if 1e-2 * cl > cls[-1]:
        return (np.nan,) * 6

    if np.isscalar(area):
        careas = np.arange(1, len(i) + 1) * area
    else:
        careas = np.cumsum(area[i])
    nsigma = np.interp(1e-2 * cl, cls, nsigmas)
    area = np.interp(1e-2 * cl, cls, careas)

    # Find the eigendecomposition of the covariance matrix. np.linalg.eigh
    # returns eigenvalues in ascending order, so the first is the minor axis.
    eigvals, eigvecs = np.linalg.eigh(c)

    # Find the semi-minor and semi-major axes.
    b, a = nsigma * np.sqrt(eigvals)

    # Find the position angle.
    pa = np.rad2deg(np.arctan2(*eigvecs[0]))

    # An ellipse is symmetric under rotations of 180°.
    # Return the smallest possible positive position angle.
    pa %= 180

    # Done!
    return ra, dec, a, b, pa, area

ligo/skymap/postprocess/util.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Postprocessing utilities for HEALPix sky maps."""

import astropy_healpix as ah
from astropy.coordinates import (CartesianRepresentation, SkyCoord,
                                 UnitSphericalRepresentation)
from astropy import units as u
import healpy as hp
import numpy as np

__all__ = ('find_greedy_credible_levels', 'smooth_ud_grade', 'posterior_mean',
           'posterior_max')


def find_greedy_credible_levels(p, ranking=None):
    """Compute the greedy credible level of every entry of an array.

    Entries are visited in descending order of ``ranking`` (or of ``p``
    itself when no ranking is given), and each entry is assigned the running
    total of ``p`` accumulated up to and including that entry.

    Parameters
    ----------
    p : np.ndarray
        The input array, typically a HEALPix image. May have any shape.

    ranking : np.ndarray, optional
        The array to rank in order to determine the greedy order.
        The default is `p` itself.

    Returns
    -------
    cls : np.ndarray
        An array with the same shape as `p`, with values ranging from `0`
        to `p.sum()`, representing the greedy credible level to which each
        entry in the array belongs.

    """
    values = np.asarray(p)
    flat_values = values.ravel()
    rank_key = flat_values if ranking is None else np.ravel(ranking)
    # Indices that visit entries from highest to lowest rank.
    descending = np.argsort(rank_key)[::-1]
    levels = np.empty_like(flat_values)
    levels[descending] = np.cumsum(flat_values[descending])
    return levels.reshape(values.shape)


def smooth_ud_grade(m, nside, nest=False):
    """Resample a sky map to a new resolution using bilinear interpolation.

    The input map is interpolated at the pixel centers of the output
    resolution, then rescaled by the ratio of pixel counts.

    Parameters
    ----------
    m : np.ndarray
        The input HEALPix array.

    nside : int
        HEALPix resolution parameter of the output map.

    nest : bool, default=False
        Indicates whether the input sky map is in nested rather than
        ring-indexed HEALPix coordinates (default: ring). The output uses the
        same ordering.

    Returns
    -------
    new_m : np.ndarray
        The resampled HEALPix array. The sum of `m` is approximately preserved.

    """
    out_npix = ah.nside_to_npix(nside)
    centers_theta, centers_phi = hp.pix2ang(
        nside, np.arange(out_npix), nest=nest)
    resampled = hp.get_interp_val(m, centers_theta, centers_phi, nest=nest)
    # Rescale so that the total of the map is approximately preserved.
    return resampled * len(m) / len(resampled)


def posterior_mean(prob, nest=False):
    """Find the probability-weighted mean sky position of a HEALPix map.

    The pixel-center unit vectors are averaged with weights ``prob``, and the
    resulting Cartesian vector is converted to a direction on the unit sphere.

    Parameters
    ----------
    prob : np.ndarray
        HEALPix array of per-pixel weights (typically probabilities).
    nest : bool, default=False
        Whether ``prob`` uses nested rather than ring pixel ordering.

    Returns
    -------
    pos : :class:`astropy.coordinates.SkyCoord`
        The mean direction, with a unit-spherical representation.

    """
    pixel_count = len(prob)
    resolution = ah.npix_to_nside(pixel_count)
    unit_vectors = hp.pix2vec(resolution, np.arange(pixel_count), nest=nest)
    x, y, z = np.average(unit_vectors, axis=1, weights=prob)
    mean_pos = SkyCoord(x, y, z, representation_type=CartesianRepresentation)
    # Drop the (meaningless) norm of the averaged vector; keep the direction.
    mean_pos.representation_type = UnitSphericalRepresentation
    return mean_pos


def posterior_max(prob, nest=False):
    """Return the sky position of the most probable pixel of a HEALPix map.

    Parameters
    ----------
    prob : np.ndarray
        HEALPix probability map; its length determines the resolution.
    nest : bool, default=False
        Whether ``prob`` is in nested rather than ring ordering.

    Returns
    -------
    pos : SkyCoord
        Center of the pixel where ``prob`` is largest (first one on ties).

    """
    nside = ah.npix_to_nside(len(prob))
    best = np.argmax(prob)
    lon, lat = hp.pix2ang(nside, best, nest=nest, lonlat=True)
    return SkyCoord(lon, lat, unit=u.deg)

ligo/skymap/tool/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Functions that support the command line interface."""

import argparse
from distutils.dir_util import mkpath
from distutils.errors import DistutilsFileError
import glob
import inspect
import itertools
import logging
import os
import sys

import numpy as np

from ..util import sqlite
from .. import version

# "<package> <version>" banner, used for --version and process metadata.
version_string = ' '.join((version.__package__, version.version))


class FileType(argparse.FileType):
    """Inherit from :class:`argparse.FileType` to enable opening stdin or
    stdout in binary mode.

    This is a workaround for https://bugs.python.org/issue14156.

    When the argument is ``-`` and the mode is binary, the binary buffer of
    standard input is returned for reading modes, and the binary buffer of
    standard output for writing, appending, or exclusive-creation modes
    (the same stdin/stdout split that :class:`argparse.FileType` uses).
    Every other argument is handled by :class:`argparse.FileType`.
    """

    def __call__(self, string):
        if string == '-' and 'b' in self._mode:
            if 'r' in self._mode:
                return sys.stdin.buffer
            # Treat 'a' and 'x' like 'w', as argparse itself does for '-';
            # otherwise 'ab'/'xb' would fall through to a text stream (or an
            # error on older Pythons).
            elif any(c in self._mode for c in 'wax'):
                return sys.stdout.buffer
        return super().__call__(string)


class EnableAction(argparse.Action):
    """Paired ``--enable-FOO`` / ``--disable-FOO`` boolean option.

    Declare only the ``--enable-`` spelling; the matching ``--disable-``
    option string is added automatically. ``--enable-FOO`` stores True and
    ``--disable-FOO`` stores False. The default is True.

    Raises
    ------
    ValueError
        If not exactly one option string is given, or if it does not start
        with ``--enable-``.
    """

    def __init__(self,
                 option_strings,
                 dest,
                 default=True,
                 required=False,
                 help=None):
        (enable_opt,) = option_strings
        if not enable_opt.startswith('--enable-'):
            raise ValueError('Option string must start with --enable-')
        disable_opt = enable_opt.replace('--enable-', '--disable-')
        super().__init__(
            [enable_opt, disable_opt],
            dest=dest,
            nargs=0,
            default=default,
            required=required,
            help=help)

    def __call__(self, parser, namespace, values, option_string):
        for prefix, flag in (('--enable-', True), ('--disable-', False)):
            if option_string.startswith(prefix):
                setattr(namespace, self.dest, flag)
                return
        raise RuntimeError('This code cannot be reached')


class GlobAction(argparse._StoreAction):
    """Generate a list of filenames from a list of filenames and globs.

    Each value is expanded with :func:`glob.iglob`, and the matches (if any)
    are stored as a list. The number of stored files is then checked against
    ``nargs``; a mismatch raises :class:`argparse.ArgumentError`.
    """

    def __call__(self, parser, namespace, values, *args, **kwargs):
        # With nargs='?' or nargs=None, argparse passes a single string (or
        # the default, possibly None) rather than a list. Normalize so that
        # whole patterns are globbed instead of individual characters.
        if values is None:
            patterns = []
        elif isinstance(values, str):
            patterns = [values]
        else:
            patterns = values
        values = list(
            itertools.chain.from_iterable(glob.iglob(s) for s in patterns))
        if values:
            super().__call__(parser, namespace, values, *args, **kwargs)
        nvalues = getattr(namespace, self.dest)
        nvalues = 0 if nvalues is None else len(nvalues)
        if self.nargs == argparse.OPTIONAL:
            msg = 'expected at most one file' if nvalues > 1 else None
        elif self.nargs == argparse.ONE_OR_MORE:
            msg = 'expected at least one file' if nvalues < 1 else None
        elif self.nargs == argparse.ZERO_OR_MORE:
            msg = None
        else:
            # nargs=None means exactly one value, as in argparse.
            expected = 1 if self.nargs is None else int(self.nargs)
            if expected != nvalues:
                msg = 'expected exactly {} file'.format(expected)
                if expected != 1:
                    msg += 's'
            else:
                msg = None
        if msg is not None:
            msg += ', but found '
            msg += '{} file'.format(nvalues)
            if nvalues != 1:
                msg += 's'
            raise argparse.ArgumentError(self, msg)


# Parent parser (no -h) holding the template waveform options; meant to be
# shared via ``parents=[waveform_parser]``.
waveform_parser = argparse.ArgumentParser(add_help=False)
group = waveform_parser.add_argument_group(
    'waveform options',
    'Options that affect template waveform generation')
# FIXME: The O1 uberbank high-mass template, SEOBNRv2_ROM_DoubleSpin, does
# not support frequencies less than 30 Hz.
group.add_argument('--f-low', metavar='Hz', type=float, default=30,
                   help='Low frequency cutoff')
group.add_argument('--f-high-truncate', type=float, default=0.95,
                   help=('Truncate waveform at this fraction of the maximum '
                         'frequency of the PSD'))
group.add_argument('--waveform', default='o2-uberbank',
                   help=('Template waveform approximant: '
                         'e.g., TaylorF2threePointFivePN'))
del group


# Parent parser (no -h) holding the BAYESTAR prior options; meant to be
# shared via ``parents=[prior_parser]``.
prior_parser = argparse.ArgumentParser(add_help=False)
group = prior_parser.add_argument_group(
    'prior options',
    'Options that affect the BAYESTAR likelihood')
group.add_argument('--min-inclination', metavar='deg', type=float,
                   default=0.0, help='Minimum inclination in degrees')
group.add_argument('--max-inclination', metavar='deg', type=float,
                   default=90.0, help='Maximum inclination in degrees')
group.add_argument('--min-distance', metavar='Mpc', type=float,
                   help='Minimum distance of prior in megaparsecs')
group.add_argument('--max-distance', metavar='Mpc', type=float,
                   help='Maximum distance of prior in megaparsecs')
group.add_argument('--prior-distance-power', metavar='-1|2', type=int,
                   default=2,
                   help=('Distance prior: -1 for uniform in log, '
                         '2 for uniform in volume'))
group.add_argument('--cosmology', action='store_true',
                   help='Use cosmological comoving volume prior')
group.add_argument('--enable-snr-series', action=EnableAction,
                   help='Enable input of SNR time series')
del group


# Parent parser (no -h) holding the BAYESTAR MCMC options; meant to be
# shared via ``parents=[mcmc_parser]``.
mcmc_parser = argparse.ArgumentParser(add_help=False)
group = mcmc_parser.add_argument_group(
    'BAYESTAR MCMC options',
    'BAYESTAR options for MCMC sampling')
group.add_argument('--mcmc', action='store_true',
                   help='Use MCMC sampling instead of Gaussian quadrature')
group.add_argument('--chain-dump', action='store_true',
                   help='For MCMC methods, dump the sample chain to disk')
del group


class HelpChoicesAction(argparse.Action):
    """Option that lists the supported values of another option and exits.

    For an option string ``--help-NAME``, invoking it prints
    ``Supported values for --NAME:`` followed by each entry of ``choices``
    on its own line, then calls ``parser.exit()``.
    """

    def __init__(self,
                 option_strings,
                 choices=(),
                 dest=argparse.SUPPRESS,
                 default=argparse.SUPPRESS):
        target = option_strings[0].replace('--help-', '')
        super().__init__(
            option_strings=option_strings,
            dest=dest,
            default=default,
            nargs=0,
            help='show supported values for --' + target + ' and exit')
        self._name = target
        self._choices = choices

    def __call__(self, parser, namespace, values, option_string=None):
        # One header line, then one line per choice.
        print('Supported values for --' + self._name + ':',
              *self._choices, sep='\n')
        parser.exit()


def type_with_sideeffect(type):
    """Decorator factory for argparse ``type`` callables with side effects.

    ``@type_with_sideeffect(T)`` applied to ``sideeffect`` produces a
    function that converts its argument with ``T``, calls ``sideeffect`` on
    the converted value, and returns the converted value (the return value of
    ``sideeffect`` is ignored).

    The produced function keeps the decorated function's ``__name__``.
    argparse uses this name in its "invalid <name> value" error messages, so
    users see e.g. "invalid seed value" rather than "invalid func value".
    """
    from functools import wraps

    def decorator(sideeffect):
        @wraps(sideeffect)
        def func(value):
            ret = type(value)
            sideeffect(ret)
            return ret
        return func
    return decorator


@type_with_sideeffect(str)
def loglevel_type(value):
    """Configure root logging at the level given by ``value``.

    ``value`` may be anything :func:`int` accepts (a numeric level) or a
    level name in any case (it is upper-cased). Through the decorator, the
    argparse-facing function returns ``str(value)``.

    Note that :func:`logging.basicConfig` has no effect if the root logger
    already has handlers.
    """
    try:
        level = int(value)
    except ValueError:
        level = value.upper()
    logging.basicConfig(
        level=level, format='%(asctime)s %(levelname)s %(message)s')


class LogLevelAction(argparse._StoreAction):
    """Store action for a log-level option.

    Whatever ``type`` and ``metavar`` the caller passes are ignored: the type
    is always :func:`loglevel_type` (which configures logging as a side
    effect), and the metavar lists the level names known to :mod:`logging`.
    """

    def __init__(
            self, option_strings, dest, nargs=None, const=None, default=None,
            type=None, choices=None, required=False, help=None, metavar=None):
        level_names = '|'.join(logging._levelToName.values())
        super().__init__(
            option_strings, dest, nargs=nargs, const=const, default=default,
            type=loglevel_type, choices=choices, required=required,
            help=help, metavar=level_names)


@type_with_sideeffect(int)
def seed(value):
    """Seed NumPy's global random number generator.

    Used as an argparse ``type``: the decorator converts the argument with
    :func:`int`, calls this function with it, and returns the integer.
    """
    np.random.seed(seed=value)


# Parent parser (no -h) holding the ``--seed`` option. Parsing ``--seed``
# seeds NumPy's global random number generator via the ``seed`` type.
random_parser = argparse.ArgumentParser(add_help=False)
group = random_parser.add_argument_group(
    'random number generator options',
    'Options that affect the Numpy pseudo-random number generator')
group.add_argument(
    '--seed', type=seed, help='Pseudo-random number generator seed '
    '[default: initialized from /dev/urandom or clock]')


class HelpFormatter(argparse.RawDescriptionHelpFormatter,
                    argparse.ArgumentDefaultsHelpFormatter):
    """Help formatter that keeps description whitespace and shows defaults.

    Combines :class:`argparse.RawDescriptionHelpFormatter` (description and
    epilog are printed verbatim) with
    :class:`argparse.ArgumentDefaultsHelpFormatter` (default values are
    appended to argument help strings).
    """


class ArgumentParser(argparse.ArgumentParser):
    """An ArgumentParser subclass with some sensible defaults.

    - Any ``.py`` suffix is stripped from the program name, because the
      program is probably being invoked from the stub shell script.
      Underscores in the program name are replaced with hyphens.

    - The description is taken from the docstring of the file in which the
      ArgumentParser is created.

    - If the description is taken from the docstring, then whitespace in
      the description is preserved.

    - A ``--version`` option is added that prints the version of ligo.skymap.

    - A ``-l``/``--loglevel`` option (default ``INFO``) is added, and the
      custom actions ``'glob'`` and ``'loglevel'`` are registered.
    """

    def __init__(self,
                 prog=None,
                 usage=None,
                 description=None,
                 epilog=None,
                 parents=None,
                 prefix_chars='-',
                 fromfile_prefix_chars=None,
                 argument_default=None,
                 conflict_handler='error',
                 add_help=True):
        # Avoid a shared mutable default argument.
        if parents is None:
            parents = []
        # The frame that constructed this parser supplies the defaults for
        # the program name (its file name) and description (its module doc).
        parent_frame = inspect.currentframe().f_back
        if prog is None:
            prog = os.path.basename(parent_frame.f_code.co_filename)
            # Strip only a trailing '.py', not '.py' elsewhere in the name.
            if prog.endswith('.py'):
                prog = prog[:-len('.py')]
            prog = prog.replace('_', '-')
        if description is None:
            description = parent_frame.f_globals.get('__doc__', None)
        super().__init__(
            prog=prog,
            usage=usage,
            description=description,
            epilog=epilog,
            parents=parents,
            formatter_class=HelpFormatter,
            prefix_chars=prefix_chars,
            fromfile_prefix_chars=fromfile_prefix_chars,
            argument_default=argument_default,
            conflict_handler=conflict_handler,
            add_help=add_help)
        self.register('action', 'glob', GlobAction)
        self.register('action', 'loglevel', LogLevelAction)
        self.add_argument(
            '--version', action='version', version=version_string)
        self.add_argument(
            '-l', '--loglevel', action='loglevel', default='INFO')


class DirType:
    """Factory for directory arguments.

    Parameters
    ----------
    create : bool, optional
        If True, create the directory (and any missing parents) when it does
        not exist. If False (the default), require that the directory
        already exists and can be listed.

    Calling an instance with a path returns the path unchanged, or raises
    :class:`argparse.ArgumentTypeError` if the directory cannot be created
    or listed.
    """

    def __init__(self, create=False):
        self._create = create

    def __call__(self, string):
        try:
            if self._create:
                # An empty string names the current directory: nothing to
                # create (os.makedirs('') would fail).
                if string:
                    os.makedirs(string, exist_ok=True)
            else:
                os.listdir(string)
        except OSError as e:
            raise argparse.ArgumentTypeError(e) from e
        return string


class SQLiteType(FileType):
    """Open an SQLite database, or fail if it does not exist.

    Parameters
    ----------
    mode : {'r', 'w', 'a'}
        File mode passed to the SQLite opener.

    Raises
    ------
    ValueError
        If ``mode`` is not exactly one of ``'r'``, ``'w'``, or ``'a'``.

    Here is an example of trying to open a file that does not exist for
    reading (mode='r'). It should raise an exception:

    >>> import tempfile
    >>> filetype = SQLiteType('r')
    >>> filename = tempfile.mktemp()
    >>> # Note, simply check for a FileNotFound error in Python 3.
    >>> filetype(filename)
    Traceback (most recent call last):
      ...
    argparse.ArgumentTypeError: ...

    If the file already exists, then it's fine:

    >>> import sqlite3
    >>> filetype = SQLiteType('r')
    >>> with tempfile.NamedTemporaryFile() as f:
    ...     with sqlite3.connect(f.name) as db:
    ...         _ = db.execute('create table foo (bar char)')
    ...     filetype(f.name)
    <sqlite3.Connection object at ...>

    Here is an example of opening a file for writing (mode='w'), which should
    overwrite the file if it exists. Even if the file was not an SQLite
    database beforehand, this should work:

    >>> filetype = SQLiteType('w')
    >>> with tempfile.NamedTemporaryFile(mode='w') as f:
    ...     print('This is definitely not an SQLite file.', file=f)
    ...     f.flush()
    ...     with filetype(f.name) as db:
    ...         db.execute('create table foo (bar char)')
    <sqlite3.Cursor object at ...>

    Here is an example of opening a file for appending (mode='a'), which should
    NOT overwrite the file if it exists. If the file was not an SQLite database
    beforehand, this should raise an exception.

    >>> filetype = SQLiteType('a')
    >>> with tempfile.NamedTemporaryFile(mode='w') as f:
    ...     print('This is definitely not an SQLite file.', file=f)
    ...     f.flush()
    ...     with filetype(f.name) as db:
    ...         db.execute('create table foo (bar char)')
    Traceback (most recent call last):
      ...
    sqlite3.DatabaseError: ...

    And if the database did exist beforehand, then opening for appending
    (mode='a') should not clobber existing tables.

    >>> filetype = SQLiteType('a')
    >>> with tempfile.NamedTemporaryFile() as f:
    ...     with sqlite3.connect(f.name) as db:
    ...         _ = db.execute('create table foo (bar char)')
    ...     with filetype(f.name) as db:
    ...         db.execute('select count(*) from foo').fetchone()
    (0,)
    """

    def __init__(self, mode):
        # Compare against whole modes: a substring test such as
        # ``mode not in 'arw'`` would also accept '', 'ar', 'rw' and 'arw'.
        if mode not in ('a', 'r', 'w'):
            raise ValueError('Unknown file mode: {}'.format(mode))
        self.mode = mode

    def __call__(self, string):
        try:
            return sqlite.open(string, self.mode)
        except OSError as e:
            raise argparse.ArgumentTypeError(e)


def _sanitize_arg_value_for_xmldoc(value):
    """Make a parsed argument value suitable for recording as process params.

    File-like objects (anything with a ``read`` attribute) are replaced by
    their ``name``. Tuples and lists are sanitized element by element; any
    tuple subclass comes back as a plain tuple and any list subclass as a
    plain list. Every other value is returned unchanged.
    """
    if hasattr(value, 'read'):
        return value.name
    if isinstance(value, (tuple, list)):
        cleaned = map(_sanitize_arg_value_for_xmldoc, value)
        return tuple(cleaned) if isinstance(value, tuple) else list(cleaned)
    return value


def register_to_xmldoc(xmldoc, parser, opts, **kwargs):
    """Record this program's invocation in a LIGO-LW document.

    Calls ``ligo.lw.utils.process.register_to_xmldoc`` with ``parser.prog``
    as the program name, the parsed options (file objects replaced by their
    names), and ``version_string``. Extra keyword arguments are forwarded;
    passing ``version`` as one of them raises :class:`TypeError`. Returns
    whatever that function returns.
    """
    from ligo.lw.utils import process
    params = {}
    for key, value in vars(opts).items():
        params[key] = _sanitize_arg_value_for_xmldoc(value)
    return process.register_to_xmldoc(
        xmldoc, parser.prog, params, version=version_string, **kwargs)


start_msg = '\
Waiting for input on stdin. Type control-D followed by a newline to terminate.'
stop_msg = 'Reached end of file. Exiting.'


def iterlines(file, start_message=start_msg, stop_message=stop_msg):
    """Iterate over non-empty lines in a file.

    Lines are read with ``file.readline()`` until it returns an empty value.
    Each line is stripped of leading and trailing whitespace, and lines that
    are empty after stripping are skipped. If ``file`` is attached to a
    terminal, ``start_message`` is printed to standard error before reading
    and ``stop_message`` after end of file.

    Parameters
    ----------
    file : file-like
        Object with a ``readline`` method. Objects without a usable file
        descriptor (e.g. :class:`io.StringIO`) are treated as not being a
        terminal.
    start_message, stop_message : str
        Prompts shown only when reading from a terminal.

    Yields
    ------
    line : str or bytes
        Stripped line, of the same type ``file.readline()`` returns.
    """
    try:
        is_tty = os.isatty(file.fileno())
    except (AttributeError, OSError, ValueError):
        # No underlying descriptor: in-memory streams raise
        # io.UnsupportedOperation (an OSError/ValueError subclass), and
        # closed files raise ValueError.
        is_tty = False

    if is_tty:
        print(start_message, file=sys.stderr)

    while True:
        # readline() is used instead of iteration so that each line is
        # handled as soon as it is typed.
        line = file.readline()

        if not line:
            # If we reached EOF, then exit.
            break

        # Strip off the trailing newline and any whitespace.
        line = line.strip()

        # Emit the line if it is not empty.
        if line:
            yield line

    if is_tty:
        print(stop_message, file=sys.stderr)


def should_gzip(filename):
    """Return True if the final extension of ``filename`` is exactly ``.gz``."""
    return os.path.splitext(filename)[1] == '.gz'


def write_fileobj(xmldoc, f):
    """Write a LIGO-LW document to the open file object ``f``.

    The output is gzip-compressed when ``f.name`` has a ``.gz`` extension.
    The write happens inside ``ligo.lw.utils.SignalsTrap()``;
    NOTE(review): presumably this defers signals so the write is not
    interrupted midway — confirm in ligo.lw.
    """
    from ligo.lw import utils as ligolw_utils

    with ligolw_utils.SignalsTrap():
        ligolw_utils.write_fileobj(xmldoc, f, gz=should_gzip(f.name))

ligo/skymap/tool/bayestar_inject.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
#
# Copyright (C) 2019-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Rough-cut injection tool.

The idea is to efficiently sample events, uniformly in "sensitive volume"
(differential comoving volume divided by 1 + z), and from a distribution of
masses and spins, such that later detection cuts will not reject an excessive
number of events.

This occurs in two steps. First, we divide the intrinsic parameter space into a
very coarse 10x10x10x10 grid and calculate the maximum horizon distance in each
grid cell. Second, we directly sample injections jointly from the mass and spin
distribution and a uniform and isotropic spatial distribution with a redshift
cutoff that is piecewise constant in the masses and spins.
"""

from functools import partial

from astropy import cosmology
from astropy.cosmology.core import vectorize_if_needed
from astropy import units
from astropy.units import dimensionless_unscaled
import lal
import numpy as np
from scipy.integrate import quad, fixed_quad
from scipy.interpolate import interp1d
from scipy.optimize import root_scalar
from scipy.ndimage import maximum_filter

from ..util import progress_map
from ..bayestar.filter import sngl_inspiral_psd
from . import (
    ArgumentParser, FileType, random_parser, register_to_xmldoc, write_fileobj)

# Set LAL's debug level to LALNDEBUG at import time.
# NOTE(review): presumably this silences LAL's debug/diagnostic output —
# confirm against the LAL documentation.
lal.ClobberDebugLevel(lal.LALNDEBUG)


def get_decisive_snr(snrs):
    """Return the SNR for the trigger that decides if an event is detectable.

    With two or more detectors, the decisive SNR is that of the second
    loudest detector (a coincidence of at least two is required). With one
    detector, it is that detector's SNR. With none, it is 0.

    Parameters
    ----------
    snrs : list
        List (or 1-D array) of SNRs (floats).

    Returns
    -------
    decisive_snr : float

    """
    count = len(snrs)
    if count == 0:
        return 0.0
    if count == 1:
        return snrs[0]
    return sorted(snrs)[-2]


def lo_hi_nonzero(x):
    """Return the first and last flat indices at which ``x`` is nonzero.

    Raises IndexError if ``x`` has no nonzero entries.
    """
    idx = np.flatnonzero(x)
    first, last = idx[0], idx[-1]
    return first, last


def z_at_snr(cosmo, psds, waveform, f_low, snr, mass1, mass2, spin1z, spin2z):
    """
    Get redshift at which a waveform attains a given SNR.

    The decisive SNR (see :func:`get_decisive_snr`) across all detectors is
    computed as a function of redshift. The redshift at which it equals
    ``snr`` is then found with :func:`scipy.optimize.root_scalar`, using the
    bracket z = 0 to z = 1000.

    Parameters
    ----------
    cosmo : :class:`astropy.cosmology.FLRW`
        The cosmological model.
    psds : list
        List of :class:`lal.REAL8FrequencySeries` objects.
    waveform : str
        Waveform approximant name.
    f_low : float
        Low-frequency cutoff for template.
    snr : float
        Target SNR.
    mass1, mass2, spin1z, spin2z : float
        Waveform parameters, passed as keyword arguments to
        ``sngl_inspiral_psd`` (presumably component masses in solar masses
        and dimensionless aligned spins; confirm there).

    Returns
    -------
    z : float
        Redshift at which the decisive SNR equals ``snr``.

    Raises
    ------
    ValueError
        From ``root_scalar`` if the SNR minus ``snr`` does not change sign
        over the bracket.

    """
    # Template spectrum as an interpolant of log(power) vs. log(frequency),
    # restricted to its nonzero band.
    template = sngl_inspiral_psd(waveform, f_low=f_low,
                                 mass1=mass1, mass2=mass2,
                                 spin1z=spin1z, spin2z=spin2z)
    i_lo, i_hi = lo_hi_nonzero(template.data.data)
    log_f = np.log(
        template.f0 + template.deltaF * np.arange(i_lo, i_hi + 1))
    log_f_lo = log_f[0]
    log_f_hi = log_f[-1]
    num = interp1d(
        log_f, np.log(template.data.data[i_lo:i_hi + 1]),
        fill_value=-np.inf, bounds_error=False, assume_sorted=True)

    # Each PSD as log(f / S(f)) vs. log(f), over its finite nonzero band.
    # The extra log(f) is the Jacobian df = f dlog(f), since the integral
    # below is taken over log frequency.
    denoms = []
    for psd in psds:
        i_lo, i_hi = lo_hi_nonzero(
            np.isfinite(psd.data.data) & (psd.data.data != 0))
        log_f = np.log(psd.f0 + psd.deltaF * np.arange(i_lo, i_hi + 1))
        denom = interp1d(
            log_f, log_f - np.log(psd.data.data[i_lo:i_hi + 1]),
            fill_value=-np.inf, bounds_error=False, assume_sorted=True)
        denoms.append(denom)

    def snr_at_z(z):
        # Shift the template to frequency (1 + z) f and integrate it against
        # each detector's PSD.
        logzp1 = np.log(z + 1)

        def integrand(log_f):
            return [np.exp(num(log_f + logzp1) + denom(log_f))
                    for denom in denoms]

        integrals, _ = fixed_quad(
            integrand, log_f_lo, log_f_hi - logzp1, n=1024)
        decisive = get_decisive_snr(np.sqrt(4 * integrals))
        # At z = 0 the angular diameter distance is 0, giving an infinite
        # SNR; suppress the divide-by-zero warning.
        with np.errstate(divide='ignore'):
            decisive /= cosmo.angular_diameter_distance(z).to_value(
                units.Mpc)
        return decisive

    return root_scalar(lambda z: snr_at_z(z) - snr, bracket=(0, 1e3)).root


def get_max_z(cosmo, psds, waveform, f_low, snr, mass1, mass2, spin1z, spin2z,
              jobs=1):
    """Evaluate :func:`z_at_snr` on a grid of masses and spins.

    Parameters
    ----------
    cosmo, psds, waveform, f_low, snr
        Passed through to :func:`z_at_snr`.
    mass1, mass2, spin1z, spin2z : np.ndarray
        1-D arrays of grid coordinates along each axis.
    jobs : int, optional
        Passed to ``progress_map`` (number of parallel jobs).

    Returns
    -------
    z : np.ndarray
        Array of shape ``(len(mass1), len(mass2), len(spin1z), len(spin2z))``
        holding the redshift at which each grid point attains ``snr``.

    """
    axes = [mass1, mass2, spin1z, spin2z]
    # Flatten the 4-D 'ij' grid, map over its points, then restore the shape.
    grid = np.meshgrid(*axes, indexing='ij')
    func = partial(z_at_snr, cosmo, psds, waveform, f_low, snr)
    flat = list(progress_map(func, *(g.ravel() for g in grid), jobs=jobs))
    result = np.reshape(flat, tuple(len(axis) for axis in axes))

    assert np.all(result >= 0), 'some redshifts are negative'
    assert np.all(np.isfinite(result)), 'some redshifts are not finite'
    return result


def _sensitive_volume_integral(cosmo, z):
    """Dimensionless sensitive-volume integral per steradian.

    Returns the integral from 0 to ``z`` of
    ``dV_c/dz / ((1 + z') * D_H**3)`` per steradian, where ``D_H`` is the
    Hubble distance. Scalars and arrays of ``z`` are both supported via
    ``vectorize_if_needed``.
    """
    hubble_volume_per_sr = cosmo.hubble_distance**3 / units.sr

    def integrand(zp):
        dvc_dz = cosmo.differential_comoving_volume(zp)
        dvc_dz /= (1 + zp) * hubble_volume_per_sr
        return dvc_dz.to_value(dimensionless_unscaled)

    def integral(z_max):
        value, _abserr = quad(integrand, 0, z_max)
        return value

    return vectorize_if_needed(integral, z)


def sensitive_volume(cosmo, z):
    """Sensitive volume :math:`V(z)` out to redshift :math:`z`.

    Given a population of events that occur at a constant rate density
    :math:`R` per unit comoving volume per unit proper time, the number
    :math:`N(z)` of events observed out to redshift :math:`z` over an
    observation time :math:`T` is :math:`N(z) = R T V(z)`.
    """
    hubble_volume = cosmo.hubble_distance**3
    return 4 * np.pi * hubble_volume * _sensitive_volume_integral(cosmo, z)


def sensitive_distance(cosmo, z):
    r"""Sensitive distance as a function of redshift :math:`z`.

    The sensitive distance :math:`d_s(z)` is defined by
    :math:`V(z) = \frac{4}{3}\pi {d_s(z)}^3`, where :math:`V(z)` is the
    sensitive volume.
    """
    scaled_integral = 3 * _sensitive_volume_integral(cosmo, z)
    return cosmo.hubble_distance * np.cbrt(scaled_integral)


def cell_max(values):
    r"""
    Find pairwise max of consecutive elements across all axes of an array.

    Parameters
    ----------
    values : :class:`numpy.ndarray`
        An input array of :math:`n` dimensions,
        :math:`(m_0, m_1, \dots, m_{n-1})`.

    Returns
    -------
    maxima : :class:`numpy.ndarray`
        An output array of :math:`n` dimensions, each with a length 1 less
        than the input array,
        :math:`(m_0 - 1, m_1 - 1, \dots, m_{n-1} - 1)`. Each entry is the
        maximum over one grid cell (a block of :math:`2^n` neighbors).

    """
    # A size-2 window covers offsets -1 and 0 along every axis. The first
    # element along each axis uses the constant padding, so drop it.
    filtered = maximum_filter(values, size=2, mode='constant')
    return filtered[tuple(slice(1, None) for _ in range(np.ndim(values)))]


def assert_not_reached():  # pragma: no cover
    """Raise AssertionError; marks branches that should be impossible."""
    message = 'This line should not be reached.'
    raise AssertionError(message)


def parser():
    """Build the command-line parser for this injection tool.

    The parser also inherits the ``--seed`` option from ``random_parser``.
    """
    p = ArgumentParser(parents=[random_parser])
    add = p.add_argument
    add('--cosmology', choices=cosmology.parameters.available,
        default='Planck15', help='Cosmological model')
    add('--distribution', required=True,
        choices=('bns_astro', 'bns_broad', 'nsbh_astro', 'nsbh_broad',
                 'bbh_astro', 'bbh_broad'))
    add('--reference-psd', metavar='PSD.xml[.gz]', type=FileType('rb'),
        required=True, help='PSD file')
    add('--f-low', type=float, default=25.0,
        help='Low frequency cutoff in Hz')
    add('--min-snr', type=float, default=4.0,
        help='Minimum decisive SNR of injections given the reference PSDs')
    add('--waveform', default='o2-uberbank', help='Waveform approximant')
    add('--nsamples', type=int, default=100000,
        help='Output this many injections')
    add('-o', '--output', metavar='INJ.xml[.gz]', type=FileType('wb'),
        default='-', help='Output file, optionally gzip-compressed')
    # '-j' with no value stores None (const).
    add('-j', '--jobs', type=int, default=1, const=None, nargs='?',
        help='Number of threads')
    return p


def main(args=None):
    """Draw a synthetic population of compact binary injections.

    Samples ``--nsamples`` injections from the population selected by
    ``--distribution``, limited to sources that are plausibly detectable
    given the reference PSDs and ``--min-snr``. The injections are written
    as a ``sim_inspiral`` table to ``--output``. The string form of the
    implied volumetric rate is stored in the process row's ``comment``.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments, forwarded to ``parse_args``.
    """
    from ligo.lw import lsctables
    from ligo.lw.utils import process as ligolw_process
    from ligo.lw import utils as ligolw_utils
    from ligo.lw import ligolw
    import lal.series
    from scipy import stats

    p = parser()
    args = p.parse_args(args)

    xmldoc = ligolw.Document()
    xmlroot = xmldoc.appendChild(ligolw.LIGO_LW())
    process = register_to_xmldoc(xmldoc, p, args)

    cosmo = cosmology.default_cosmology.get_cosmology_from_string(
        args.cosmology)

    # Component mass bounds. These are source-frame masses; they are
    # redshifted to the observer frame near the end of this function.
    ns_mass_min = 1.0
    ns_mass_max = 2.0
    bh_mass_min = 5.0
    bh_mass_max = 50.0

    # "astro" populations: NS masses ~ Normal(mean=1.33, std=0.09), small
    # aligned spins. stats.uniform(loc, scale) covers [loc, loc + scale].
    ns_astro_spin_min = -0.05
    ns_astro_spin_max = +0.05
    ns_astro_mass_dist = stats.norm(1.33, 0.09)
    ns_astro_spin_dist = stats.uniform(
        ns_astro_spin_min, ns_astro_spin_max - ns_astro_spin_min)

    # "broad" NS population: uniform masses and wider spin range.
    ns_broad_spin_min = -0.4
    ns_broad_spin_max = +0.4
    ns_broad_mass_dist = stats.uniform(ns_mass_min, ns_mass_max - ns_mass_min)
    ns_broad_spin_dist = stats.uniform(
        ns_broad_spin_min, ns_broad_spin_max - ns_broad_spin_min)

    # BH astro masses follow a Pareto law (b=1.3, default scale). The prior is
    # only ever evaluated through CDF differences on the grid below, so
    # samples are implicitly truncated to [bh_mass_min, bh_mass_max].
    bh_astro_spin_min = -0.99
    bh_astro_spin_max = +0.99
    bh_astro_mass_dist = stats.pareto(b=1.3)
    bh_astro_spin_dist = stats.uniform(
        bh_astro_spin_min, bh_astro_spin_max - bh_astro_spin_min)

    # BH broad masses are log-uniform (reciprocal) on [bh_mass_min, bh_mass_max].
    bh_broad_spin_min = -0.99
    bh_broad_spin_max = +0.99
    bh_broad_mass_dist = stats.reciprocal(bh_mass_min, bh_mass_max)
    bh_broad_spin_dist = stats.uniform(
        bh_broad_spin_min, bh_broad_spin_max - bh_broad_spin_min)

    # Pick grid bounds and priors for component 1 (mass1/spin1z) and
    # component 2 (mass2/spin2z). The prefix selects the source class and
    # the suffix selects the astro/broad population; argparse `choices`
    # guarantees one branch matches.
    if args.distribution.startswith('bns_'):
        m1_min = m2_min = ns_mass_min
        m1_max = m2_max = ns_mass_max
        if args.distribution.endswith('_astro'):
            x1_min = x2_min = ns_astro_spin_min
            x1_max = x2_max = ns_astro_spin_max
            m1_dist = m2_dist = ns_astro_mass_dist
            x1_dist = x2_dist = ns_astro_spin_dist
        elif args.distribution.endswith('_broad'):
            x1_min = x2_min = ns_broad_spin_min
            x1_max = x2_max = ns_broad_spin_max
            m1_dist = m2_dist = ns_broad_mass_dist
            x1_dist = x2_dist = ns_broad_spin_dist
        else:  # pragma: no cover
            assert_not_reached()
    elif args.distribution.startswith('nsbh_'):
        m1_min = bh_mass_min
        m1_max = bh_mass_max
        m2_min = ns_mass_min
        m2_max = ns_mass_max
        if args.distribution.endswith('_astro'):
            x1_min = bh_astro_spin_min
            x1_max = bh_astro_spin_max
            x2_min = ns_astro_spin_min
            x2_max = ns_astro_spin_max
            m1_dist = bh_astro_mass_dist
            m2_dist = ns_astro_mass_dist
            x1_dist = bh_astro_spin_dist
            x2_dist = ns_astro_spin_dist
        elif args.distribution.endswith('_broad'):
            x1_min = bh_broad_spin_min
            x1_max = bh_broad_spin_max
            x2_min = ns_broad_spin_min
            x2_max = ns_broad_spin_max
            m1_dist = bh_broad_mass_dist
            m2_dist = ns_broad_mass_dist
            x1_dist = bh_broad_spin_dist
            x2_dist = ns_broad_spin_dist
        else:  # pragma: no cover
            assert_not_reached()
    elif args.distribution.startswith('bbh_'):
        m1_min = m2_min = bh_mass_min
        m1_max = m2_max = bh_mass_max
        if args.distribution.endswith('_astro'):
            x1_min = x2_min = bh_astro_spin_min
            x1_max = x2_max = bh_astro_spin_max
            m1_dist = m2_dist = bh_astro_mass_dist
            x1_dist = x2_dist = bh_astro_spin_dist
        elif args.distribution.endswith('_broad'):
            x1_min = x2_min = bh_broad_spin_min
            x1_max = x2_max = bh_broad_spin_max
            m1_dist = m2_dist = bh_broad_mass_dist
            x1_dist = x2_dist = bh_broad_spin_dist
        else:  # pragma: no cover
            assert_not_reached()
    else:  # pragma: no cover
        assert_not_reached()

    # Order must match `params` below: (mass1, mass2, spin1z, spin2z).
    dists = (m1_dist, m2_dist, x1_dist, x2_dist)

    # Read PSDs
    psds = list(
        lal.series.read_psd_xmldoc(
            ligolw_utils.load_fileobj(
                args.reference_psd,
                contenthandler=lal.series.PSDContentHandler)).values())

    # Construct mass1, mass2, spin1z, spin2z grid.
    # 10 nodes per axis, i.e. 9 cells per axis. Masses are geometrically
    # spaced and spins linearly spaced.
    m1 = np.geomspace(m1_min, m1_max, 10)
    m2 = np.geomspace(m2_min, m2_max, 10)
    x1 = np.linspace(x1_min, x1_max, 10)
    x2 = np.linspace(x2_min, x2_max, 10)
    params = m1, m2, x1, x2

    # Calculate the maximum distance on the grid.
    # max_z is evaluated at each grid node; sensitive_distance converts it to
    # a distance in Mpc.
    max_z = get_max_z(
        cosmo, psds, args.waveform, args.f_low, args.min_snr, m1, m2, x1, x2,
        jobs=args.jobs)
    max_distance = sensitive_distance(cosmo, max_z).to_value(units.Mpc)

    # Find piecewise constant approximate upper bound on distance.
    # cell_max turns node values into per-cell values (each axis shrinks by
    # one), taking the max over each cell's corners.
    max_distance = cell_max(max_distance)

    # Calculate V * T in each grid cell
    # cdf_los: CDF at each cell's lower edge (used for inverse-CDF sampling
    # below). After the diff, cdfs[k] holds the prior mass of each 1-D cell
    # along axis k.
    cdfs = [dist.cdf(param) for param, dist in zip(params, dists)]
    cdf_los = [cdf[:-1] for cdf in cdfs]
    cdfs = [np.diff(cdf) for cdf in cdfs]
    # Joint prior mass of each 4-D cell, as the outer product of the
    # independent 1-D masses, normalized to 1. It is then weighted by the
    # cell's maximum sensitive volume 4/3 pi d^3 (Mpc^3). The sum of the
    # weighted array is the prior-averaged sensitive volume; dividing by it
    # gives the cell sampling probabilities.
    probs = np.prod(np.meshgrid(*cdfs, indexing='ij'), axis=0)
    probs /= probs.sum()
    probs *= 4/3*np.pi*max_distance**3
    volume = probs.sum()
    probs /= volume
    probs = probs.ravel()

    # Rate that yields nsamples events per year within `volume`. The
    # time_geocent draw below spans one year, consistent with this.
    volumetric_rate = args.nsamples / volume * units.year**-1 * units.Mpc**-3

    # Draw random grid cells
    # `dist` here is a discrete distribution over flattened cell indices. The
    # list comprehension below binds its own, unrelated loop variable `dist`.
    dist = stats.rv_discrete(values=(np.arange(len(probs)), probs))
    indices = np.unravel_index(
        dist.rvs(size=args.nsamples), max_distance.shape)

    # Draw random intrinsic params from each cell
    # Inverse-CDF sampling confined to each chosen cell: draw u uniformly in
    # [cdf_lo, cdf_lo + cell mass] and map it through the prior's ppf.
    # NOTE: the order of .rvs calls in this function determines the output
    # for a given seed; do not reorder them.
    cols = {}
    cols['mass1'], cols['mass2'], cols['spin1z'], cols['spin2z'] = [
        dist.ppf(stats.uniform(cdf_lo[i], cdf[i]).rvs(size=args.nsamples))
        for i, dist, cdf_lo, cdf in zip(indices, dists, cdf_los, cdfs)]

    # Draw random extrinsic parameters
    # powerlaw(a=3) has pdf proportional to d^2 on [0, scale], i.e. uniform
    # in volume out to the cell's maximum sensitive distance. Sky position
    # and inclination are drawn isotropically (arcsin/arccos of uniform
    # variates on [-1, 1]).
    cols['distance'] = stats.powerlaw(a=3, scale=max_distance[indices]).rvs(
        size=args.nsamples)
    cols['longitude'] = stats.uniform(0, 2 * np.pi).rvs(
        size=args.nsamples)
    cols['latitude'] = np.arcsin(stats.uniform(-1, 2).rvs(
        size=args.nsamples))
    cols['inclination'] = np.arccos(stats.uniform(-1, 2).rvs(
        size=args.nsamples))
    cols['polarization'] = stats.uniform(0, 2 * np.pi).rvs(
        size=args.nsamples)
    cols['coa_phase'] = stats.uniform(-np.pi, 2 * np.pi).rvs(
        size=args.nsamples)
    # One year of seconds starting at 1e9 (presumably GPS time -- confirm).
    cols['time_geocent'] = stats.uniform(1e9, units.year.to(units.second)).rvs(
        size=args.nsamples)

    # Convert from sensitive distance to redshift and comoving distance.
    # FIXME: Replace this brute-force lookup table with a solver.
    # Tabulate sensitive distance and comoving distance on a fine z grid up
    # to the largest max_z, then interpolate z(ds) and dc(ds) (assumes ds is
    # increasing in z, since assume_sorted=True).
    z = np.linspace(0, max_z.max(), 10000)
    ds = sensitive_distance(cosmo, z).to_value(units.Mpc)
    dc = cosmo.comoving_distance(z).to_value(units.Mpc)
    z_for_ds = interp1d(ds, z, kind='cubic', assume_sorted=True)
    dc_for_ds = interp1d(ds, dc, kind='cubic', assume_sorted=True)
    zp1 = 1 + z_for_ds(cols['distance'])
    cols['distance'] = dc_for_ds(cols['distance'])

    # Apply redshift factor to convert from comoving distance and source frame
    # masses to luminosity distance and observer frame masses.
    for key in ['distance', 'mass1', 'mass2']:
        cols[key] *= zp1

    # Populate sim_inspiral table
    # Every valid column starts as None; bookkeeping fields and the sampled
    # columns in `cols` override those defaults for each row.
    sims = xmlroot.appendChild(lsctables.New(lsctables.SimInspiralTable))
    for row in zip(*cols.values()):
        sims.appendRow(
            **dict(
                dict.fromkeys(sims.validcolumns, None),
                process_id=process.process_id,
                simulation_id=sims.get_next_id(),
                waveform=args.waveform,
                f_lower=args.f_low,
                **dict(zip(cols.keys(), row))))

    # Record process end time.
    # The implied volumetric rate is stored as text in the process comment.
    process.comment = str(volumetric_rate)
    ligolw_process.set_process_end_time(process)

    # Write output file.
    write_fileobj(xmldoc, args.output)

ligo/skymap/tool/bayestar_localize_coincs.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Produce GW sky maps for all coincidences in search pipeline output database
in LIGO-LW XML, LIGO-LW SQLite, or PyCBC HDF5 format.

The distance prior is controlled by the ``--prior-distance-power`` argument.
If you set ``--prior-distance-power`` to k, then the distance prior is
proportional to r^k. The default is 2, uniform in volume.

If the ``--min-distance`` argument is omitted, it defaults to zero. If the
``--max-distance`` argument is omitted, it defaults to the SNR=4 horizon
distance of the most sensitive detector.

A FITS file is created for each sky map, having a filename of the form
``X.fits`` where X is the integer LIGO-LW row ID of the coinc. The ``OBJECT``
card in the FITS header is also set to the integer row ID.
"""

from . import (
    ArgumentParser, FileType, mkpath,
    waveform_parser, prior_parser, mcmc_parser, random_parser)


# Text stored in each sky map's ``comment`` metadata (see main): the
# explanation of the OBJECT card, padded with one empty line on each side.
ROW_ID_COMMENT = [''] + [
    'The integer value in the OBJECT card in this FITS header is a row ID',
    'that refers to a coinc_event table row in the input LIGO-LW document.',
] + ['']


def parser():
    """Construct the command-line parser for batch coinc localization.

    Returns
    -------
    parser : ArgumentParser
        Parser that inherits the waveform, prior, MCMC, and random parent
        parsers and adds input-selection, output, and Condor options.
    """
    # The PyCBC input needs several companion files; spell them all out.
    input_help = (
        'Input LIGO-LW XML file, SQLite file, or PyCBC HDF5 files. '
        'For PyCBC, you must supply the coincidence file '
        '(e.g. "H1L1-HDFINJFIND.hdf" or "H1L1-STATMAP.hdf"), '
        'the template bank file (e.g. H1L1-BANK2HDF.hdf), '
        'the single-detector merged PSD files '
        '(e.g. "H1-MERGE_PSDS.hdf" and "L1-MERGE_PSDS.hdf"), '
        'and the single-detector merged trigger files '
        '(e.g. "H1-HDF_TRIGGER_MERGE.hdf" and '
        '"L1-HDF_TRIGGER_MERGE.hdf"), '
        'in any order.')

    p = ArgumentParser(
        parents=[waveform_parser, prior_parser, mcmc_parser, random_parser])
    p.add_argument('-d', '--disable-detector', metavar='X1', type=str,
                   nargs='+', help='disable certain detectors')
    p.add_argument('--keep-going', '-k', default=False, action='store_true',
                   help='Keep processing events if a sky map fails to '
                        'converge')
    p.add_argument('input', metavar='INPUT.{hdf,xml,xml.gz,sqlite}',
                   default='-', nargs='+', type=FileType('rb'),
                   help=input_help)
    p.add_argument('--pycbc-sample', default='foreground',
                   help='(PyCBC only) sample population')
    p.add_argument('--coinc-event-id', type=int, nargs='*',
                   help='run on only these specified events')
    p.add_argument('--output', '-o', default='.', help='output directory')
    p.add_argument('--condor-submit', action='store_true',
                   help='submit to Condor instead of running locally')
    return p


def main(args=None):
    """Localize each coinc in the input and write one FITS sky map per coinc.

    Sky maps are written to ``--output`` as ``<coinc_event_id>.fits``. With
    ``--chain-dump``, the MCMC chain is written to the same directory as
    ``<coinc_event_id>.hdf5``. With ``--condor-submit``, nothing is
    localized locally. Instead, one Condor job per coinc is submitted and the
    process exits with ``condor_submit``'s return code.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments, forwarded to ``parse_args``.

    Raises
    ------
    NotImplementedError
        If ``--seed`` is combined with ``--condor-submit``.
    ValueError
        If ``--coinc-event-id`` is combined with ``--condor-submit``.
    RuntimeError
        If ``--keep-going`` is set and one or more sky maps failed. Without
        ``--keep-going`` the first failure is re-raised instead.
    """
    opts = parser().parse_args(args)

    import logging
    log = logging.getLogger('BAYESTAR')

    # BAYESTAR imports.
    from .. import omp
    from ..io import fits, events
    from ..bayestar import localize

    # Other imports.
    import os
    from collections import OrderedDict
    import numpy as np
    import subprocess
    import sys

    # Squelch annoying and uninformative LAL log messages.
    import lal
    lal.ClobberDebugLevel(lal.LALNDEBUG)

    log.info('Using %d OpenMP thread(s)', omp.num_threads)

    # Read coinc file.
    log.info(
        '%s:reading input files', ','.join(file.name for file in opts.input))
    event_source = events.open(*opts.input, sample=opts.pycbc_sample)

    if opts.disable_detector:
        event_source = events.detector_disabled.open(
            event_source, opts.disable_detector)

    mkpath(opts.output)

    if opts.condor_submit:
        if opts.seed is not None:
            raise NotImplementedError(
                '--seed does not yet work with --condor-submit')
        if opts.coinc_event_id:
            raise ValueError(
                'must not set --coinc-event-id with --condor-submit')
        # The submit description below is a non-raw string, so each
        # backslash-newline is a Python line continuation and
        # on_exit_hold_reason reaches condor_submit as a single line.
        with subprocess.Popen(['condor_submit'],
                              # FIXME: use text=True instead in Python >= 3.7
                              encoding=sys.stdin.encoding,
                              stdin=subprocess.PIPE) as proc:
            f = proc.stdin
            print('''
                  accounting_group = ligo.dev.o3.cbc.pe.bayestar
                  on_exit_remove = (ExitBySignal == False) && (ExitCode == 0)
                  on_exit_hold = (ExitBySignal == True) || (ExitCode != 0)
                  on_exit_hold_reason = (ExitBySignal == True \
                    ? strcat("The job exited with signal ", ExitSignal) \
                    : strcat("The job exited with code ", ExitCode))
                  request_memory = 1000 MB
                  universe = vanilla
                  getenv = true
                  executable = /usr/bin/env
                  JobBatchName = BAYESTAR
                  environment = "OMP_NUM_THREADS=1"
                  ''', file=f)
            print('error =', os.path.join(opts.output, '$(cid).err'), file=f)
            # Re-run this same command line for each coinc, minus the
            # --condor-submit flag, restricted to one coinc per job.
            print('arguments = "',
                  *(arg for arg in sys.argv if arg != '--condor-submit'),
                  '--coinc-event-id $(cid)"', file=f)
            print('queue cid in', *event_source, file=f)
        sys.exit(proc.returncode)

    if opts.coinc_event_id:
        event_source = OrderedDict(
            (key, event_source[key]) for key in opts.coinc_event_id)

    count_sky_maps_failed = 0

    # Loop over all sngl_inspiral <-> sngl_inspiral coincs.
    for coinc_event_id, event in event_source.items():
        log.info('%d:computing sky map', coinc_event_id)
        # Put the chain dump in the output directory next to the FITS file,
        # like every other output of this tool, instead of in the current
        # working directory.
        if opts.chain_dump:
            chain_dump = os.path.join(opts.output, f'{coinc_event_id}.hdf5')
        else:
            chain_dump = None
        try:
            sky_map = localize(
                event, opts.waveform, opts.f_low,
                np.deg2rad(opts.min_inclination),
                np.deg2rad(opts.max_inclination),
                opts.min_distance,
                opts.max_distance, opts.prior_distance_power,
                opts.cosmology, mcmc=opts.mcmc, chain_dump=chain_dump,
                enable_snr_series=opts.enable_snr_series,
                f_high_truncate=opts.f_high_truncate)
            sky_map.meta['objid'] = coinc_event_id
            sky_map.meta['comment'] = ROW_ID_COMMENT

        except (ArithmeticError, ValueError):
            log.exception('%d:sky localization failed', coinc_event_id)
            count_sky_maps_failed += 1
            if not opts.keep_going:
                raise
        else:
            log.info('%d:saving sky map', coinc_event_id)
            filename = f'{coinc_event_id}.fits'
            fits.write_sky_map(
                os.path.join(opts.output, filename), sky_map, nest=True)

    if count_sky_maps_failed > 0:
        raise RuntimeError("{0} sky map{1} did not converge".format(
            count_sky_maps_failed, 's' if count_sky_maps_failed > 1 else ''))

ligo/skymap/tool/bayestar_localize_lvalert.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
r"""
Listen for new events from LVAlert and perform sky localization.

`bayestar-localize-lvalert` supports two modes of operation. You can
explicitly specify the GraceDb ID on the command line, as in::

    $ bayestar-localize-lvalert T90713

Or, `bayestar-localize-lvalert` can read GraceDB IDs from stdin (e.g., from
the terminal, or redirected from a fifo)::

    $ mkfifo /var/run/bayestar
    $ tail -F /var/run/bayestar | bayestar-localize-lvalert &
    $ echo T90713 > /var/run/bayestar
"""

from . import (
    ArgumentParser, EnableAction, waveform_parser, prior_parser, mcmc_parser,
    random_parser, iterlines)


def parser():
    """Construct the command-line parser for GraceDB sky localization.

    Returns
    -------
    parser : ArgumentParser
        Parser that inherits the waveform, prior, MCMC, and random parent
        parsers and adds GraceDB, output, and detector-selection options.
    """
    ap = ArgumentParser(
        parents=[waveform_parser, prior_parser, mcmc_parser, random_parser])
    ap.add_argument('-d', '--disable-detector', metavar='X1', type=str,
                    nargs='+', help='disable certain detectors')
    ap.add_argument('-N', '--dry-run', action='store_true',
                    help='Dry run; do not update GraceDB entry')
    ap.add_argument('--no-tag', action='store_true',
                    help='Do not set lvem tag for GraceDB entry')
    ap.add_argument('-o', '--output', metavar='FILE.fits[.gz]',
                    default='bayestar.fits', help='Name for uploaded file')
    ap.add_argument('--enable-multiresolution', action=EnableAction,
                    default=True,
                    help='generate a multiresolution HEALPix map')
    # Zero positional IDs is allowed; main() then reads IDs from stdin.
    graceid_help = (
        'Run on these GraceDB IDs. If no GraceDB IDs are listed on the '
        'command line, then read newline-separated GraceDB IDs from stdin.')
    ap.add_argument('graceid', metavar='G123456', nargs='*',
                    help=graceid_help)
    return ap


def main(args=None):
    """Localize GraceDB events and upload (or, in dry-run mode, keep) sky maps.

    GraceDB IDs come from the command line. If none are given, they are read
    from stdin, one per line. Each event is localized and written as a FITS
    file named ``--output``. That file is uploaded to GraceDB as a log entry,
    or, with ``--dry-run``, moved into the current directory.

    Outside dry-run mode, log records at INFO and above are also posted to
    the event being processed. Failures are logged and the loop moves on.
    KeyboardInterrupt is always re-raised, and any failure is re-raised in
    dry-run mode.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments, forwarded to ``parse_args``.
    """
    opts = parser().parse_args(args)

    import logging
    import os
    import re
    import sys
    import tempfile
    import urllib.parse
    from ..bayestar import localize, rasterize
    from ..io import fits
    from ..io import events
    from .. import omp
    from ..util.file import rename
    import ligo.gracedb.logging
    import ligo.gracedb.rest
    import numpy as np

    # Squelch annoying and uninformative LAL log messages.
    import lal
    lal.ClobberDebugLevel(lal.LALNDEBUG)

    log = logging.getLogger('BAYESTAR')

    log.info('Using %d OpenMP thread(s)', omp.num_threads)

    # If no GraceDB IDs were specified on the command line, then read them
    # from stdin line-by-line.
    graceids = opts.graceid if opts.graceid else iterlines(sys.stdin)

    # Fire up a GraceDb client
    # FIXME: Mimic the behavior of the GraceDb command line client, where the
    # environment variable GRACEDB_SERVICE_URL overrides the default service
    # URL. It would be nice to get this behavior into the gracedb package
    # itself.
    gracedb = ligo.gracedb.rest.GraceDb(
        os.environ.get(
            'GRACEDB_SERVICE_URL', ligo.gracedb.rest.DEFAULT_SERVICE_URL))

    # Determine the base URL for event pages.
    scheme, netloc, *_ = urllib.parse.urlparse(gracedb._service_url)
    base_url = urllib.parse.urlunparse((scheme, netloc, 'events', '', '', ''))

    if opts.chain_dump:
        # Derive the chain dump name from the FITS name. The dots are escaped
        # so that only a literal ".fits" or ".fits.gz" suffix is replaced.
        # When there is no such suffix, ".hdf5" is appended, so the chain
        # dump never gets the same name as the FITS output.
        chain_dump, n_replaced = re.subn(
            r'\.fits(\.gz)?$', '.hdf5', opts.output)
        if not n_replaced:
            chain_dump += '.hdf5'
    else:
        chain_dump = None

    tags = ("sky_loc",)
    if not opts.no_tag:
        tags += ("lvem",)

    event_source = events.gracedb.open(graceids, gracedb)

    if opts.disable_detector:
        event_source = events.detector_disabled.open(
            event_source, opts.disable_detector)

    for graceid in event_source.keys():

        try:
            event = event_source[graceid]
        except Exception:
            # Skip events that cannot be fetched. Catching Exception rather
            # than using a bare except lets KeyboardInterrupt and SystemExit
            # stop the listener.
            log.exception('failed to read event %s from GraceDB', graceid)
            continue

        # Send log messages to GraceDb too
        handler = None
        if not opts.dry_run:
            handler = ligo.gracedb.logging.GraceDbLogHandler(gracedb, graceid)
            handler.setLevel(logging.INFO)
            logging.root.addHandler(handler)

        # A little bit of Cylon humor
        log.info('by your command...')

        try:
            # perform sky localization
            log.info("starting sky localization")
            sky_map = localize(
                event, opts.waveform, opts.f_low,
                np.deg2rad(opts.min_inclination),
                np.deg2rad(opts.max_inclination),
                opts.min_distance, opts.max_distance,
                opts.prior_distance_power, opts.cosmology,
                mcmc=opts.mcmc, chain_dump=chain_dump,
                enable_snr_series=opts.enable_snr_series,
                f_high_truncate=opts.f_high_truncate)
            if not opts.enable_multiresolution:
                sky_map = rasterize(sky_map)
            sky_map.meta['objid'] = str(graceid)
            sky_map.meta['url'] = '{}/{}'.format(base_url, graceid)
            log.info("sky localization complete")

            # upload FITS file
            with tempfile.TemporaryDirectory() as fitsdir:
                fitspath = os.path.join(fitsdir, opts.output)
                fits.write_sky_map(fitspath, sky_map, nest=True)
                log.debug('wrote FITS file: %s', opts.output)
                if opts.dry_run:
                    rename(fitspath, os.path.join('.', opts.output))
                else:
                    gracedb.writeLog(
                        graceid, "BAYESTAR rapid sky localization ready",
                        filename=fitspath, tagname=tags)
                log.debug('uploaded FITS file')
        except KeyboardInterrupt:
            # Produce log message and then exit if we receive SIGINT (ctrl-C).
            log.exception("sky localization failed")
            raise
        except Exception:
            # Produce log message for any otherwise uncaught exception.
            # Unless we are in dry-run mode, keep going.
            log.exception("sky localization failed")
            if opts.dry_run:
                # Then re-raise the exception if we are in dry-run mode
                raise
        finally:
            # Detach this event's GraceDB handler even when re-raising, so
            # that later records are never posted to the wrong event.
            if handler is not None:
                logging.root.removeHandler(handler)

ligo/skymap/tool/bayestar_mcmc.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Markov-Chain Monte Carlo sky localization."""

from . import (
    ArgumentParser, FileType, waveform_parser,
    prior_parser, random_parser, mkpath)


def parser():
    """Build the command-line parser for MCMC sky localization.

    Returns
    -------
    parser : ArgumentParser
        Parser that inherits the waveform, prior, and random parent parsers
        and adds input/output options plus a group of options that hold
        individual parameters constant.
    """
    p = ArgumentParser(parents=[waveform_parser, prior_parser,
                                random_parser])
    p.add_argument(
        'input', metavar='INPUT.{hdf,xml,xml.gz,sqlite}', default='-',
        nargs='+', type=FileType('rb'),
        help=('Input LIGO-LW XML file, SQLite file, or PyCBC HDF5 files. '
              'For PyCBC, you must supply the coincidence file '
              '(e.g. "H1L1-HDFINJFIND.hdf" or "H1L1-STATMAP.hdf"), '
              'the template bank file (e.g. H1L1-BANK2HDF.hdf), '
              'the single-detector merged PSD files '
              '(e.g. "H1-MERGE_PSDS.hdf" and "L1-MERGE_PSDS.hdf"), '
              'and the single-detector merged trigger files '
              '(e.g. "H1-HDF_TRIGGER_MERGE.hdf" and '
              '"L1-HDF_TRIGGER_MERGE.hdf"), '
              'in any order.'))
    p.add_argument('--pycbc-sample', default='foreground',
                   help='(PyCBC only) sample population')
    p.add_argument('--coinc-event-id', type=int, nargs='*',
                   help='run on only these specified events')
    p.add_argument('--output', '-o', default='.', help='output directory')
    p.add_argument('--condor-submit', action='store_true',
                   help='submit to Condor instead of running locally')

    fixed = p.add_argument_group(
        'fixed parameter options',
        'Options to hold certain parameters constant')
    # (flag, metavar, help) for each optional fixed parameter, in help order.
    for flag, metavar, description in (
            ('--ra', 'DEG', 'Right ascension'),
            ('--dec', 'DEG', 'Declination'),
            ('--distance', 'Mpc', 'Luminosity distance')):
        fixed.add_argument(flag, type=float, metavar=metavar,
                           help=description)

    return p


def identity(x):
    """Return the argument itself, unchanged (the object is not copied)."""
    result = x
    return result


def main(args=None):
    """Run MCMC sky localization on each selected event and save the samples.

    With ``--condor-submit``, nothing is sampled locally. A Condor submit
    description with one job per event is piped to ``condor_submit``; each
    job re-runs this command with ``--coinc-event-id $(cid)``. The process
    then exits with ``condor_submit``'s return code.

    Otherwise, for each event (optionally restricted by
    ``--coinc-event-id``), the data and distance prior are conditioned,
    posterior samples are drawn with ``ez_emcee``, and the samples are
    written to ``<output>/<coinc_event_id>.hdf5`` at HDF5 path
    ``/bayestar/posterior_samples``.

    Parameters
    ----------
    args : list of str, optional
        Passed to ``parser().parse_args``.

    Raises
    ------
    NotImplementedError
        If ``--seed`` is combined with ``--condor-submit``.
    ValueError
        If ``--coinc-event-id`` is combined with ``--condor-submit``.
    """
    opts = parser().parse_args(args)

    import logging
    log = logging.getLogger('BAYESTAR')

    # BAYESTAR imports.
    from ..io import events, hdf5
    from ..bayestar import condition, condition_prior, ez_emcee, log_post

    # Other imports.
    from astropy.table import Table
    import numpy as np
    import os
    from collections import OrderedDict
    import subprocess
    import sys

    # Squelch annoying and uninformative LAL log messages.
    import lal
    lal.ClobberDebugLevel(lal.LALNDEBUG)

    # Read coinc file.
    log.info(
        '%s:reading input files', ','.join(file.name for file in opts.input))
    event_source = events.open(*opts.input, sample=opts.pycbc_sample)

    mkpath(opts.output)

    if opts.condor_submit:
        # Batch mode: submit one Condor job per event instead of sampling
        # here. Each job re-runs this command restricted to a single event.
        if opts.seed is not None:
            raise NotImplementedError(
                '--seed does not yet work with --condor-submit')
        if opts.coinc_event_id:
            raise ValueError(
                'must not set --coinc-event-id with --condor-submit')
        # The submit description is written to condor_submit's stdin.
        with subprocess.Popen(['condor_submit'],
                              # FIXME: use text=True instead in Python >= 3.7
                              encoding=sys.stdin.encoding,
                              stdin=subprocess.PIPE) as proc:
            f = proc.stdin
            print('''
                  accounting_group = ligo.dev.o3.cbc.pe.bayestar
                  on_exit_remove = (ExitBySignal == False) && (ExitCode == 0)
                  on_exit_hold = (ExitBySignal == True) || (ExitCode != 0)
                  on_exit_hold_reason = (ExitBySignal == True \
                    ? strcat("The job exited with signal ", ExitSignal) \
                    : strcat("The job exited with code ", ExitCode))
                  request_memory = 1000 MB
                  universe = vanilla
                  getenv = true
                  executable = /usr/bin/env
                  JobBatchName = BAYESTAR
                  environment = "OMP_NUM_THREADS=1"
                  ''', file=f)
            # Per-job stderr and Condor log files, named by event ID.
            print('error =', os.path.join(opts.output, '$(cid).err'), file=f)
            print('log =', os.path.join(opts.output, '$(cid).log'), file=f)
            # The executable is /usr/bin/env, so the job's arguments are this
            # process's full command line (including argv[0]) without
            # --condor-submit, plus --coinc-event-id for the queued event.
            # NOTE(review): built from sys.argv, not from *args*, so a
            # programmatic main(args) call submits the process's own argv.
            print('arguments = "',
                  *(arg for arg in sys.argv if arg != '--condor-submit'),
                  '--coinc-event-id $(cid)"', file=f)
            # One job per item yielded by iterating event_source (presumably
            # the event IDs used as keys below).
            print('queue cid in', *event_source, file=f)
        # Leaving the with-block closes stdin and waits for condor_submit,
        # so returncode is set here.
        sys.exit(proc.returncode)

    if opts.coinc_event_id:
        # Restrict processing to the requested events, in the given order.
        event_source = OrderedDict(
            (key, event_source[key]) for key in opts.coinc_event_id)

    # Loop over all sngl_inspiral <-> sngl_inspiral coincs.
    for int_coinc_event_id, event in event_source.items():
        coinc_event_id = 'coinc_event:coinc_event_id:{}'.format(
            int_coinc_event_id)

        log.info('%s:preparing', coinc_event_id)

        epoch, sample_rate, epochs, snrs, responses, locations, horizons = \
            condition(event, waveform=opts.waveform, f_low=opts.f_low,
                      enable_snr_series=opts.enable_snr_series,
                      f_high_truncate=opts.f_high_truncate)

        min_distance, max_distance, prior_distance_power, cosmology = \
            condition_prior(horizons, opts.min_distance, opts.max_distance,
                            opts.prior_distance_power, opts.cosmology)

        gmst = lal.GreenwichMeanSiderealTime(epoch)

        # Sampling box. Entries are in the *transformed* coordinates listed
        # in transformed_names:
        #   ra in [0, 2 pi], sin(dec) in [-1, 1],
        #   distance within the conditioned prior limits,
        #   u = cos(inclination) in [-1, 1], twopsi in [0, 2 pi],
        #   time in [0, 2 * max_abs_t].
        max_abs_t = 2 * snrs.data.shape[1] / sample_rate
        xmin = [0, -1, min_distance, -1, 0, 0]
        xmax = [2 * np.pi, 1, max_distance, 1, 2 * np.pi, 2 * max_abs_t]
        names = 'ra dec distance inclination twopsi time'.split()
        transformed_names = 'ra sin_dec distance u twopsi time'.split()
        # forward_transforms[i] maps names[i] -> transformed_names[i];
        # reverse_transforms[i] maps it back.
        forward_transforms = [identity, np.sin, identity,
                              np.cos, identity, identity]
        reverse_transforms = [identity, np.arcsin, identity,
                              np.arccos, identity, identity]
        # Fixed (non-sampled) inputs passed to log_post.
        kwargs = dict(min_distance=min_distance, max_distance=max_distance,
                      prior_distance_power=prior_distance_power,
                      cosmology=cosmology, gmst=gmst, sample_rate=sample_rate,
                      epochs=epochs, snrs=snrs, responses=responses,
                      locations=locations, horizons=horizons)

        # Fix parameters
        # Each of ra/dec/distance given on the command line is handled the
        # same way. Its transformed value goes into kwargs under its
        # transformed name, and its entry is removed from every per-parameter
        # list so that it is not sampled. Iterating in reverse keeps the
        # indices 0..i-1 valid after each deletion.
        for i, key in reversed(list(enumerate(['ra', 'dec', 'distance']))):
            value = getattr(opts, key)
            if value is None:
                continue

            if key in ['ra', 'dec']:
                # FIXME: figure out a more elegant way to address different
                # units in command line arguments and posterior samples
                value = np.deg2rad(value)

            kwargs[transformed_names[i]] = forward_transforms[i](value)
            del (xmin[i], xmax[i], names[i], transformed_names[i],
                 forward_transforms[i], reverse_transforms[i])

        log.info('%s:sampling', coinc_event_id)

        # Run MCMC
        chain = ez_emcee(log_post, xmin, xmax, kwargs=kwargs, vectorize=True)

        # Transform back from sin_dec to dec and cos_inclination to inclination
        # (column i of the chain corresponds to names[i]).
        for i, func in enumerate(reverse_transforms):
            chain[:, i] = func(chain[:, i])

        # Create Astropy table
        chain = Table(rows=chain, names=names, copy=False)

        log.info('%s:saving posterior samples', coinc_event_id)

        hdf5.write_samples(
            chain,
            os.path.join(opts.output, '{}.hdf5'.format(int_coinc_event_id)),
            path='/bayestar/posterior_samples', overwrite=True)

ligo/skymap/tool/bayestar_realize_coincs.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
252  
253  
254  
255  
256  
257  
258  
259  
260  
261  
262  
263  
264  
265  
266  
267  
268  
269  
270  
271  
272  
273  
274  
275  
276  
277  
278  
279  
280  
281  
282  
283  
284  
285  
286  
287  
288  
289  
290  
291  
292  
293  
294  
295  
296  
297  
298  
299  
300  
301  
302  
303  
304  
305  
306  
307  
308  
309  
310  
311  
312  
313  
314  
315  
316  
317  
318  
319  
320  
321  
322  
323  
324  
325  
326  
327  
328  
329  
330  
331  
332  
333  
334  
335  
336  
337  
338  
339  
340  
341  
342  
343  
344  
345  
346  
347  
348  
349  
350  
351  
352  
353  
354  
355  
356  
357  
358  
359  
360  
361  
362  
363  
364  
365  
366  
367  
368  
369  
370  
371  
372  
373  
374  
375  
376  
377  
378  
379  
380  
381  
382  
383  
384  
385  
386  
387  
388  
389  
390  
391  
392  
393  
394  
395  
396  
397  
398  
399  
400  
401  
402  
403  
404  
405  
406  
407  
408  
409  
410  
411  
412  
413  
414  
415  
416  
417  
418  
419  
420  
421  
422  
423  
424  
425  
426  
427  
428  
429  
430  
431  
432  
433  
434  
435  
436  
437  
438  
439  
440  
441  
442  
443  
444  
445  
446  
447  
448  
449  
450  
451  
452  
453  
454  
455  
456  
457  
458  
459  
460  
461  
462  
463  
464  
465  
466  
467  
468  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Synthesize triggers for simulated sources by realizing Gaussian measurement
errors in SNR and time of arrival. The input file (or stdin if the input file
is omitted) should be an optionally gzip-compressed LIGO-LW XML file of the
form produced by lalapps_inspinj. The output file (or stdout if omitted) will
be an optionally gzip-compressed LIGO-LW XML file containing single-detector
triggers and coincidences.

The root-mean square measurement error depends on the SNR of the signal, so
there is a choice for how to generate perturbed time and phase measurements:

 - `zero-noise`: no measurement error at all
 - `gaussian-noise`: measurement error for a matched filter in Gaussian noise
"""

import copy
import functools

import lal
import numpy as np

from . import (
    ArgumentParser, EnableAction, FileType, random_parser, register_to_xmldoc,
    write_fileobj)

# Squelch annoying and uninformative LAL log messages.
# This runs at module import time (not inside main), so merely importing
# this module applies it.
lal.ClobberDebugLevel(lal.LALNDEBUG)


def parser():
    """Build the command-line parser for synthesizing coincident triggers.

    The ``--detector`` choices are the prefixes of every detector in
    ``lal.CachedDetectors``, sorted alphabetically.
    """
    # Detectors known to LAL, used for both the choices and the metavar.
    ifo_choices = sorted(
        cached.frDetector.prefix for cached in lal.CachedDetectors)

    ap = ArgumentParser(parents=[random_parser])
    add = ap.add_argument

    # Input/output files and parallelism.
    add('input', metavar='IN.xml[.gz]', type=FileType('rb'),
        default='-', help='Name of input file')
    add('-o', '--output', metavar='OUT.xml[.gz]', type=FileType('wb'),
        default='-', help='Name of output file')
    add('-j', '--jobs', type=int, default=1, const=None, nargs='?',
        help='Number of threads')

    # Detector network and waveform.
    add('--detector', metavar='|'.join(ifo_choices), nargs='+',
        help='Detectors to use', choices=ifo_choices, required=True)
    add('--waveform',
        help='Waveform to use for injections (overrides values in '
        'sim_inspiral table)')

    # Detection criteria.
    add('--snr-threshold', type=float, default=4.,
        help='Single-detector SNR threshold')
    add('--net-snr-threshold', type=float, default=12.,
        help='Network SNR threshold')
    add('--keep-subthreshold', action='store_true',
        help='Keep sub-threshold triggers that do not contribute to the '
        'network SNR')
    add('--min-triggers', type=int, default=2,
        help='Emit coincidences only when at least this many triggers '
        'are found')

    # Injection selection.
    add('--min-distance', type=float, default=0.0,
        help='Skip events with distance less than this value')
    add('--max-distance', type=float, default=float('inf'),
        help='Skip events with distance greater than this value')

    # Measurement model and output content.
    add('--measurement-error', default='zero-noise',
        choices=('zero-noise', 'gaussian-noise'),
        help='How to compute the measurement error')
    add('--enable-snr-series', action=EnableAction,
        help='Enable output of SNR time series')

    # Noise and frequency band.
    add('--reference-psd', metavar='PSD.xml[.gz]', type=FileType('rb'),
        required=True, help='Name of PSD file')
    add('--f-low', type=float,
        help='Override low frequency cutoff found in sim_inspiral table')
    add('--f-high', type=float,
        help='Set high frequency cutoff to simulate early warning')
    add('--duty-cycle', type=float, default=1.0,
        help='Single-detector duty cycle')
    add('-P', '--preserve-ids', action='store_true',
        help='Preserve original simulation_ids')
    return ap


def simulate_snr(ra, dec, psi, inc, distance, epoch, gmst, H, S,
                 response, location, measurement_error='zero-noise',
                 duration=0.1):
    """Simulate one detector's matched-filter trigger for a source.

    Builds the expected complex SNR time series (the template
    autocorrelation scaled by the complex amplitude ``z``), optionally adds
    a Gaussian-noise realization, and shifts it to the arrival time. It then
    locates and interpolates the peak and truncates the series around it.

    Parameters
    ----------
    ra, dec, psi, inc : float
        Sky position, polarization angle, and inclination, passed to
        ``lal.ComputeDetAMResponse`` / ``lal.TimeDelayFromEarthCenter``
        (presumably radians -- TODO confirm).
    distance : float
        Source distance. ``z`` scales as ``horizon / distance``, so it must
        use the same units as the signal model's horizon distance.
    epoch : LIGOTimeGPS
        Geocentric arrival time. Only ``gpsSeconds``/``gpsNanoSeconds`` are
        read.
    gmst : float
        Greenwich mean sidereal time for ``epoch``.
    H, S
        Signal model and noise PSD, combined by
        ``filter.signal_psd_series``.
    response, location
        Detector response tensor and location.
    measurement_error : str
        ``'gaussian-noise'`` adds noise drawn from NumPy's global random
        state. Any other value (e.g. ``'zero-noise'``) adds none.
    duration : float
        Autocorrelation length passed to ``filter.autocorrelation``.

    Returns
    -------
    tuple
        ``(horizon, snr, phase, toa, tseries)``:
        the horizon distance; the magnitude and argument of the interpolated
        peak SNR; the peak time (``LIGOTimeGPS``); and a ``COMPLEX8``
        time series named ``'snr'`` with ``2 * (n // 2) + 1`` samples
        centered on the nearest-sample peak.
    """
    from scipy.interpolate import interp1d

    from ..bayestar import filter
    from ..bayestar.interpolation import interpolate_max

    # Calculate whitened template autocorrelation sequence.
    HS = filter.signal_psd_series(H, S)
    n = len(HS.data.data)
    acor, sample_rate = filter.autocorrelation(HS, duration)

    # Calculate time, amplitude, and phase.
    u = np.cos(inc)
    u2 = np.square(u)
    signal_model = filter.SignalModel(HS)
    horizon = signal_model.get_horizon_distance()
    Fplus, Fcross = lal.ComputeDetAMResponse(response, ra, dec, psi, gmst)
    toa = lal.TimeDelayFromEarthCenter(location, ra, dec, epoch)
    z = (0.5 * (1 + u2) * Fplus + 1j * u * Fcross) * horizon / distance

    # Calculate complex autocorrelation sequence.
    # Two-sided: conjugated negative lags followed by lags 0..len(acor)-1,
    # for 2 * len(acor) - 1 samples with zero lag at index len(acor) - 1.
    snr_series = z * np.concatenate((acor[:0:-1].conj(), acor))

    # If requested, add noise.
    if measurement_error == 'gaussian-noise':
        sigmasq = 4 * np.sum(HS.deltaF * np.abs(HS.data.data))
        amp = 4 * n * HS.deltaF**0.5 * np.sqrt(HS.data.data / sigmasq)
        N = lal.CreateCOMPLEX16FrequencySeries(
            '', HS.epoch, HS.f0, HS.deltaF, HS.sampleUnits, n)
        N.data.data = amp * (
            np.random.randn(n) + 1j * np.random.randn(n))
        noise_term, sample_rate_2 = filter.autocorrelation(
            N, len(snr_series) / sample_rate, normalize=False)
        assert sample_rate == sample_rate_2
        snr_series += noise_term

    # Shift SNR series to the nearest sample.
    # Split the sub-second arrival offset (in samples) into a whole-sample
    # part and a fractional part, then move the fraction into (-0.5, 0.5].
    int_samples, frac_samples = divmod(
        (1e-9 * epoch.gpsNanoSeconds + toa) * sample_rate, 1)
    if frac_samples > 0.5:
        int_samples += 1
        frac_samples -= 1
    epoch = lal.LIGOTimeGPS(epoch.gpsSeconds, 0)
    # From here on, n is the index of the zero-lag sample (no longer the
    # frequency-series length used above).
    n = len(acor) - 1
    mprime = np.arange(-n, n + 1)
    m = mprime + frac_samples
    # Resample real and imaginary parts (cubic) onto the integer grid; values
    # shifted in from outside the original support are zero.
    re, im = (
        interp1d(m, x, kind='cubic', bounds_error=False, fill_value=0)(mprime)
        for x in (snr_series.real, snr_series.imag))
    snr_series = re + 1j * im

    # Find the trigger values.
    # The peak search is limited to the central window of indices
    # [n - n//2, n + n//2] around zero lag.
    i_nearest = np.argmax(np.abs(snr_series[n-n//2:n+n//2+1])) + n-n//2
    i_interp, z_interp = interpolate_max(i_nearest, snr_series,
                                         n // 2, method='lanczos')
    toa = epoch + (int_samples + i_interp - n) / sample_rate
    snr = np.abs(z_interp)
    phase = np.angle(z_interp)

    # Shift and truncate the SNR time series.
    epoch += (int_samples + i_nearest - n - n // 2) / sample_rate
    snr_series = snr_series[(i_nearest - n // 2):(i_nearest + n // 2 + 1)]
    tseries = lal.CreateCOMPLEX8TimeSeries(
        'snr', epoch, 0, 1 / sample_rate,
        lal.DimensionlessUnit, len(snr_series))
    tseries.data.data = snr_series
    return horizon, snr, phase, toa, tseries


def simulate(seed, sim_inspiral, psds, responses, locations, measurement_error,
             f_low=None, f_high=None, waveform=None):
    """Simulate the per-detector triggers for a single injection.

    NumPy's global random state is reseeded with *seed* first, so results
    are reproducible per injection regardless of which worker runs it.

    Parameters
    ----------
    seed : int
        Seed for ``np.random.seed``.
    sim_inspiral
        One row of the sim_inspiral table.
    psds, responses, locations : sequence
        Per-detector PSDs, response tensors, and locations, iterated in
        lockstep.
    measurement_error : str
        Forwarded to ``simulate_snr``.
    f_low, f_high, waveform : optional
        Low/high frequency cutoffs and waveform name. A falsy ``f_low`` or
        ``waveform`` falls back to the row's ``f_lower`` / ``waveform``.

    Returns
    -------
    list
        One ``simulate_snr`` result tuple per detector, in input order.
    """
    from ..bayestar import filter

    np.random.seed(seed)

    # Extrinsic parameters of the injection.
    distance = sim_inspiral.distance
    ra = sim_inspiral.longitude
    dec = sim_inspiral.latitude
    inc = sim_inspiral.inclination
    # phi = sim_inspiral.coa_phase  # arbitrary?
    psi = sim_inspiral.polarization
    epoch = sim_inspiral.time_geocent
    gmst = lal.GreenwichMeanSiderealTime(epoch)

    # Command-line overrides win; otherwise use the values in the row.
    if not f_low:
        f_low = sim_inspiral.f_lower
    if not waveform:
        waveform = sim_inspiral.waveform

    # Signal model shared by all detectors, built from intrinsic parameters.
    intrinsic = {
        attr: getattr(sim_inspiral, attr)
        for attr in ('mass1', 'mass2',
                     'spin1x', 'spin1y', 'spin1z',
                     'spin2x', 'spin2y', 'spin2z')}
    signal = filter.sngl_inspiral_psd(
        waveform, f_min=f_low, f_final=f_high, **intrinsic)

    results = []
    for psd, response, location in zip(psds, responses, locations):
        results.append(simulate_snr(
            ra, dec, psi, inc, distance, epoch, gmst, signal, psd,
            response, location, measurement_error=measurement_error))
    return results


def main(args=None):
    """Synthesize coincident triggers for the injections in the input file.

    Steps:

    1. Read the reference PSDs and the sim_inspiral table.
    2. Drop injections outside ``[--min-distance, --max-distance]``.
    3. Simulate per-detector SNR measurements for every remaining injection,
       in parallel when ``--jobs`` != 1.
    4. Keep only injections that pass the duty-cycle, single-detector SNR,
       trigger-count, and network-SNR cuts.

    The output LIGO-LW document contains:

    - the accepted injections and their sngl_inspiral triggers;
    - the coinc_event, coinc_inspiral, and coinc_event_map rows tying the
      triggers together;
    - a second set of coincidences associating each injection with its
      inspiral coincidence.

    Parameters
    ----------
    args : list of str, optional
        Passed to ``parser().parse_args``.
    """
    p = parser()
    opts = p.parse_args(args)

    # LIGO-LW XML imports.
    from ligo.lw import ligolw
    from ligo.lw.param import Param
    from ligo.lw.utils import process as ligolw_process
    from ligo.lw.utils.search_summary import append_search_summary
    from ligo.lw import utils as ligolw_utils
    from ligo.lw.lsctables import (
        New, CoincDefTable, CoincInspiralTable, CoincMapTable,
        CoincTable, ProcessParamsTable, ProcessTable, SimInspiralTable,
        SnglInspiralTable, TimeSlideTable)

    # glue, LAL and pylal imports.
    from ligo import segments
    import lal
    import lal.series
    import lalsimulation
    from lalinspiral.inspinjfind import InspiralSCExactCoincDef
    from lalinspiral.thinca import InspiralCoincDef
    from tqdm import tqdm

    # BAYESTAR imports.
    from ..io.events.ligolw import ContentHandler
    from ..bayestar import filter
    from ..util.progress import progress_map

    # Read PSDs and keep only the requested detectors, in --detector order.
    xmldoc = ligolw_utils.load_fileobj(
        opts.reference_psd, contenthandler=lal.series.PSDContentHandler)
    psds = lal.series.read_psd_xmldoc(xmldoc, root_name=None)
    psds = {
        key: filter.InterpolatedPSD(filter.abscissa(psd), psd.data.data)
        for key, psd in psds.items() if psd is not None}
    psds = [psds[ifo] for ifo in opts.detector]

    # Extract simulation table from injection file.
    inj_xmldoc = ligolw_utils.load_fileobj(
        opts.input, contenthandler=ContentHandler)
    orig_sim_inspiral_table = SimInspiralTable.get_table(inj_xmldoc)

    # Prune injections that are outside distance limits.
    orig_sim_inspiral_table[:] = [
        row for row in orig_sim_inspiral_table
        if opts.min_distance <= row.distance <= opts.max_distance]

    # Open output file.
    xmldoc = ligolw.Document()
    xmlroot = xmldoc.appendChild(ligolw.LIGO_LW())

    # Create tables. Process and ProcessParams tables are copied from the
    # injection file.
    coinc_def_table = xmlroot.appendChild(New(CoincDefTable))
    coinc_inspiral_table = xmlroot.appendChild(New(CoincInspiralTable))
    coinc_map_table = xmlroot.appendChild(New(CoincMapTable))
    coinc_table = xmlroot.appendChild(New(CoincTable))
    xmlroot.appendChild(ProcessParamsTable.get_table(inj_xmldoc))
    xmlroot.appendChild(ProcessTable.get_table(inj_xmldoc))
    sim_inspiral_table = xmlroot.appendChild(New(SimInspiralTable))
    sngl_inspiral_table = xmlroot.appendChild(New(SnglInspiralTable))
    time_slide_table = xmlroot.appendChild(New(TimeSlideTable))

    # Write process metadata to output file.
    process = register_to_xmldoc(
        xmldoc, p, opts, ifos=opts.detector, comment="Simulated coincidences")

    # Add search summary to output file.
    all_time = segments.segment([lal.LIGOTimeGPS(0), lal.LIGOTimeGPS(2e9)])
    append_search_summary(xmldoc, process, inseg=all_time, outseg=all_time)

    # Create a time slide entry.  Needed for coinc_event rows.
    time_slide_id = time_slide_table.get_time_slide_id(
        {ifo: 0 for ifo in opts.detector}, create_new=process)

    # Populate CoincDef table.
    inspiral_coinc_def = copy.copy(InspiralCoincDef)
    inspiral_coinc_def.coinc_def_id = coinc_def_table.get_next_id()
    coinc_def_table.append(inspiral_coinc_def)
    found_coinc_def = copy.copy(InspiralSCExactCoincDef)
    found_coinc_def.coinc_def_id = coinc_def_table.get_next_id()
    coinc_def_table.append(found_coinc_def)

    # Precompute values that are common to all simulations.
    detectors = [lalsimulation.DetectorPrefixToLALDetector(ifo)
                 for ifo in opts.detector]
    responses = [det.response for det in detectors]
    locations = [det.location for det in detectors]

    if opts.jobs != 1:
        from .. import omp
        omp.num_threads = 1  # disable OpenMP parallelism

    func = functools.partial(simulate, psds=psds,
                             responses=responses, locations=locations,
                             measurement_error=opts.measurement_error,
                             f_low=opts.f_low, f_high=opts.f_high,
                             waveform=opts.waveform)

    # Make sure that each simulation gets a different random number state.
    # We start by drawing a random integer s in the main thread, and
    # then the i'th simulation will seed itself with the integer i + s.
    #
    # The seed must be an unsigned 32-bit integer, so if there are n
    # seeds in total, then s must be drawn from the interval [0, 2**32 - n).
    #
    # Note that *we* are seed 0, so there are a total of
    # n = 1 + len(orig_sim_inspiral_table) seeds. The output table
    # sim_inspiral_table is still empty at this point, so it must not be
    # used to size the interval: doing so lets i + s overflow 2**32 - 1,
    # and np.random.seed then raises.
    num_seeds = 1 + len(orig_sim_inspiral_table)
    seed = np.random.randint(0, 2 ** 32 - num_seeds)
    np.random.seed(seed)

    # Coinc IDs of the accepted inspiral coincidences, kept in lockstep with
    # the rows appended to sim_inspiral_table.
    inspiral_coinc_ids = []

    with tqdm(desc='accepted') as progress:
        for sim_inspiral, simulation in zip(
                orig_sim_inspiral_table,
                progress_map(
                    func,
                    np.arange(len(orig_sim_inspiral_table)) + seed + 1,
                    orig_sim_inspiral_table, jobs=opts.jobs)):

            sngl_inspirals = []
            used_snr_series = []
            net_snr = 0.0
            count_triggers = 0

            # Loop over individual detectors and create SnglInspiral entries.
            for ifo, (horizon, abs_snr, arg_snr, toa, series) \
                    in zip(opts.detector, simulation):

                if np.random.uniform() > opts.duty_cycle:
                    # Detector not observing: no trigger at all.
                    continue
                elif abs_snr >= opts.snr_threshold:
                    # Above threshold: a found trigger that contributes to
                    # the network SNR.
                    count_triggers += 1
                    net_snr += np.square(abs_snr)
                elif not opts.keep_subthreshold:
                    # If SNR < threshold, then the injection is not found in
                    # this detector. Skip it unless --keep-subthreshold.
                    continue

                # Create SnglInspiral entry.
                used_snr_series.append(series)
                sngl_inspirals.append(
                    sngl_inspiral_table.RowType(**dict(
                        dict.fromkeys(sngl_inspiral_table.validcolumns, None),
                        process_id=process.process_id,
                        ifo=ifo,
                        mass1=sim_inspiral.mass1,
                        mass2=sim_inspiral.mass2,
                        spin1x=sim_inspiral.spin1x,
                        spin1y=sim_inspiral.spin1y,
                        spin1z=sim_inspiral.spin1z,
                        spin2x=sim_inspiral.spin2x,
                        spin2y=sim_inspiral.spin2y,
                        spin2z=sim_inspiral.spin2z,
                        end=toa,
                        snr=abs_snr,
                        coa_phase=arg_snr,
                        f_final=opts.f_high,
                        eff_distance=horizon / abs_snr)))

            net_snr = np.sqrt(net_snr)

            # If too few triggers were found, then skip this event.
            if count_triggers < opts.min_triggers:
                continue

            # If network SNR < threshold, then the injection is not found.
            # Skip it.
            if net_snr < opts.net_snr_threshold:
                continue

            # A coincidence needs at least one trigger (e.g. for its mean end
            # time). This is only reachable when --min-triggers <= 0.
            if not sngl_inspirals:
                continue

            # Add Coinc table entry.
            coinc = coinc_table.appendRow(
                coinc_event_id=coinc_table.get_next_id(),
                process_id=process.process_id,
                coinc_def_id=inspiral_coinc_def.coinc_def_id,
                time_slide_id=time_slide_id,
                insts=opts.detector,
                # One event per sngl_inspiral mapped to this coinc below.
                # Detectors dropped for duty cycle or SNR are not counted.
                nevents=len(sngl_inspirals),
                likelihood=None)

            # Add CoincInspiral table entry.
            coinc_inspiral_table.appendRow(
                coinc_event_id=coinc.coinc_event_id,
                instruments=[
                    sngl_inspiral.ifo for sngl_inspiral in sngl_inspirals],
                end=lal.LIGOTimeGPS(1e-9 * np.mean([
                    sngl_inspiral.end.ns()
                    for sngl_inspiral in sngl_inspirals
                    if sngl_inspiral.end is not None])),
                mass=sim_inspiral.mass1 + sim_inspiral.mass2,
                mchirp=sim_inspiral.mchirp,
                combined_far=0.0,  # Not provided
                false_alarm_rate=0.0,  # Not provided
                minimum_duration=None,  # Not provided
                snr=net_snr)

            # Record all sngl_inspiral records and associate them with coincs.
            for sngl_inspiral, series in zip(sngl_inspirals, used_snr_series):
                # Give this sngl_inspiral record an id and add it to the table.
                sngl_inspiral.event_id = sngl_inspiral_table.get_next_id()
                sngl_inspiral_table.append(sngl_inspiral)

                if opts.enable_snr_series:
                    elem = lal.series.build_COMPLEX8TimeSeries(series)
                    elem.appendChild(
                        Param.from_pyvalue('event_id', sngl_inspiral.event_id))
                    xmlroot.appendChild(elem)

                # Add CoincMap entry.
                coinc_map_table.appendRow(
                    coinc_event_id=coinc.coinc_event_id,
                    table_name=sngl_inspiral_table.tableName,
                    event_id=sngl_inspiral.event_id)

            # Record injection
            if not opts.preserve_ids:
                sim_inspiral.simulation_id = sim_inspiral_table.get_next_id()
            sim_inspiral_table.append(sim_inspiral)
            inspiral_coinc_ids.append(coinc.coinc_event_id)

            progress.update()

    # Record coincidence associating injections with events, using the
    # explicitly recorded inspiral coinc IDs rather than assuming they were
    # numbered 0, 1, 2, ...
    for sim_inspiral, inspiral_coinc_id in zip(
            sim_inspiral_table, inspiral_coinc_ids):
        coinc = coinc_table.appendRow(
            coinc_event_id=coinc_table.get_next_id(),
            process_id=process.process_id,
            coinc_def_id=found_coinc_def.coinc_def_id,
            time_slide_id=time_slide_id,
            instruments=None,
            nevents=None,
            likelihood=None)
        coinc_map_table.appendRow(
            coinc_event_id=coinc.coinc_event_id,
            table_name=sim_inspiral_table.tableName,
            event_id=sim_inspiral.simulation_id)
        coinc_map_table.appendRow(
            coinc_event_id=coinc.coinc_event_id,
            table_name=coinc_table.tableName,
            event_id=inspiral_coinc_id)

    # Record process end time.
    ligolw_process.set_process_end_time(process)

    # Write output file.
    write_fileobj(xmldoc, opts.output)

ligo/skymap/tool/bayestar_sample_model_psd.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
#
# Copyright (C) 2014-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Construct a LIGO-LW XML power spectral density file for a network of
detectors by evaluating a model power noise sensitivity curve.
"""

from argparse import SUPPRESS
import inspect

from . import ArgumentParser, FileType, register_to_xmldoc, write_fileobj

# Common prefix of the lalsimulation PSD model functions. Users select a
# model on the command line by the part of its name after this prefix; the
# full function name is rebuilt as psd_name_prefix + <model name>.
psd_name_prefix = 'SimNoisePSD'


def parser():
    """Build the command-line parser for evaluating model noise PSDs.

    The available PSD models are discovered by scanning ``lalsimulation``
    for callables named ``SimNoisePSD*`` whose SWIG docstring advertises
    either ``(double f) -> double`` or
    ``(REAL8FrequencySeries psd, double flow) -> int``.

    For every detector in ``lal.CachedDetectors``, a ``--<IFO>`` option
    (choose a model) and a ``--<IFO>-scale`` option (range scale factor)
    are added. Both default to ``SUPPRESS``, so they are absent from the
    parsed namespace unless given.
    """
    import lal
    import lalsimulation

    def _has_supported_signature(func):
        # Callables without a docstring (``__doc__ is None``) simply don't
        # match, instead of raising TypeError on ``'...' in None``.
        doc = func.__doc__ or ''
        return ('(double f) -> double' in doc or
                '(REAL8FrequencySeries psd, double flow) -> int' in doc)

    # Get names of PSD functions.
    psd_names = sorted(
        name[len(psd_name_prefix):]
        for name, func in inspect.getmembers(lalsimulation)
        if name.startswith(psd_name_prefix) and callable(func)
        and _has_supported_signature(func))

    parser = ArgumentParser()
    parser.add_argument(
        '-o', '--output', metavar='OUT.xml[.gz]', type=FileType('wb'),
        default='-', help='Name of output file [default: stdout]')
    parser.add_argument(
        '--df', metavar='Hz', type=float, default=1.0,
        help='Frequency step size [default: %(default)s]')
    parser.add_argument(
        '--f-max', metavar='Hz', type=float, default=2048.0,
        help='Maximum frequency [default: %(default)s]')

    detector_group = parser.add_argument_group(
        'detector noise curves',
        'Options to select noise curves for detectors.\n\n'
        'All detectors support the following options:\n\n' +
        '\n'.join(psd_names))

    scale_group = parser.add_argument_group(
        'detector scaling',
        'Options to apply scale factors to noise curves for detectors.\n'
        'For example, a scale factor of 2 means that the amplitude spectral\n'
        'density is multiplied by 1/2 so that the range is multiplied by 2.')

    # Add options for individual detectors
    for detector in lal.CachedDetectors:
        name = detector.frDetector.name
        prefix = detector.frDetector.prefix
        detector_group.add_argument(
            '--' + prefix, choices=psd_names,
            metavar='func', default=SUPPRESS,
            help='PSD function for {0} detector'.format(name))
        scale_group.add_argument(
            '--' + prefix + '-scale', type=float, default=SUPPRESS,
            help='Scale range for {0} detector'.format(name))

    return parser


def main(args=None):
    """Generate the requested noise PSDs and write them as LIGO-LW XML.

    For each detector in ``lal.CachedDetectors`` whose ``--<PREFIX>`` option
    was given, the chosen lalsimulation PSD is evaluated on a frequency grid
    that starts at 0 Hz, has spacing ``--df`` and ``floor(f_max / df)``
    samples. Curves produced by a series-filling function are trimmed to
    their nonzero span. Each PSD is then divided by the square of the
    optional ``--<PREFIX>-scale`` factor.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments; ``None`` lets argparse read ``sys.argv``.
    """
    p = parser()
    opts = p.parse_args(args)

    import lal.series
    import lalsimulation
    import numpy as np
    from ..bayestar.filter import vectorize_swig_psd_func

    # Build one frequency series per requested detector, keyed by prefix.

    psds = {}

    # Uniform frequency grid: f = 0, df, 2*df, ... (n samples, all < f_max).
    n = int(opts.f_max // opts.df)
    f = np.arange(n) * opts.df

    detectors = [d.frDetector.prefix for d in lal.CachedDetectors]

    for detector in detectors:
        # Per-detector options default to SUPPRESS, so the attribute is
        # missing (-> None here) unless the user selected a noise curve.
        psd_name = getattr(opts, detector, None)
        if psd_name is None:
            continue
        # A range scale factor s multiplies the ASD by 1/s, i.e. the PSD by
        # 1/s**2; with no --<PREFIX>-scale option the factor is 1.
        scale = 1 / np.square(getattr(opts, detector + '_scale', 1.0))
        func = getattr(lalsimulation, psd_name_prefix + psd_name)
        series = lal.CreateREAL8FrequencySeries(
            psd_name, 0, 0, opts.df, lal.SecondUnit, n)
        if '(double f) -> double' in func.__doc__:
            # Pointwise PSD function: evaluate it on the whole grid at once
            # (presumably vectorize_swig_psd_func wraps the scalar SWIG
            # function for array input -- see bayestar.filter).
            series.data.data = vectorize_swig_psd_func(
                psd_name_prefix + psd_name)(f)
        else:
            # Series-filling PSD function: it writes into `series` in place,
            # starting from flow = 0.0.
            func(series, 0.0)

            # Find indices of first and last nonzero samples.
            # NOTE(review): an all-zero result would make nonzero empty and
            # raise IndexError below.
            nonzero = np.flatnonzero(series.data.data)
            # FIXME: int cast seems to be needed on old versions of Numpy
            first_nonzero = int(nonzero[0])
            last_nonzero = int(nonzero[-1])

            # Truncate to the nonzero span and shift the start frequency to
            # match the first retained sample.
            series = lal.CutREAL8FrequencySeries(
                series, first_nonzero, last_nonzero - first_nonzero + 1)
            series.f0 = first_nonzero * series.deltaF

            series.name = psd_name
        series.data.data *= scale
        psds[detector] = series

    # Serialize all PSDs into one XML document, record the invocation via
    # register_to_xmldoc, and write it to the --output file object.
    xmldoc = lal.series.make_psd_xmldoc(psds)
    register_to_xmldoc(xmldoc, p, opts)
    write_fileobj(xmldoc, opts.output)

ligo/skymap/tool/ligo_skymap_combine.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
#
# Copyright (C) 2018-2020  Tito Dal Canton, Eric Burns, Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Combine different sky localizations of a common event observed by different
instruments in order to form a more constrained localization.

If one of the input maps contains distance information (for instance from
BAYESTAR or LALInference) then the marginal distance posterior in the output
map is updated according to the restriction in sky location imposed by the
other input map(s). Only one input map can currently have distance
information.
"""

from . import ArgumentParser, FileType


def parser():
    """Return the argument parser for combining sky localizations."""
    result = ArgumentParser()
    result.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'), nargs='+',
        help='Input sky localizations')
    # FIXME the output option has type str because astropy.io.fits.writeto()
    # only honors the .gz extension when given a file name string (as of 3.0.1)
    result.add_argument(
        'output', metavar='OUTPUT.fits[.gz]', type=str,
        help='Output combined sky localization')
    result.add_argument(
        '--origin', type=str,
        help=('Optional tag describing the organization'
              ' responsible for the combined output'))
    return result


def main(args=None):
    """Combine several sky localizations of one event into a joint map.

    The per-pixel probabilities of all inputs are upsampled to the finest
    input resolution, multiplied together and renormalized. At most one
    input may carry distance layers. If one does, those layers are carried
    over (upsampled if needed) and the marginal distance mean and standard
    deviation are recomputed for the combined sky map.

    Raises
    ------
    RuntimeError
        If more than one input has distance information, or if the product
        of the input sky maps is zero everywhere.
    """
    args = parser().parse_args(args)

    import numpy as np
    import astropy_healpix as ah
    from astropy.io import fits
    from astropy.time import Time
    import healpy as hp

    from ..distance import parameters_to_marginal_moments
    from ..io import read_sky_map, write_sky_map

    # Each entry is (nside, prob, metadata, merged header of HDUs 0 and 1).
    input_skymaps = []
    dist_mu = dist_sigma = dist_norm = None
    for input_file in args.input:
        with fits.open(input_file) as hdus:
            header = hdus[0].header.copy()
            header.extend(hdus[1].header)
            has_distance = 'DISTMU' in hdus[1].columns.names
            data, meta = read_sky_map(hdus, nest=True,
                                      distances=has_distance)

        if has_distance:
            if dist_mu is not None:
                raise RuntimeError('only one input localization can have'
                                   ' distance information')
            dist_mu = data[1]
            dist_sigma = data[2]
            dist_norm = data[3]
        else:
            # Wrap the bare probability array so that data[0] is the sky map
            # in both cases.
            data = (data,)

        nside = ah.npix_to_nside(len(data[0]))
        input_skymaps.append((nside, data[0], meta, header))

    max_nside = max(x[0] for x in input_skymaps)

    # Upsample sky posteriors to the maximum resolution and combine them.
    # ud_grade without power=-2 copies each parent pixel's probability into
    # its children. That rescales a whole map by a constant factor, which
    # cancels in the normalization below.
    combined_prob = None
    for nside, prob, _, _ in input_skymaps:
        if nside < max_nside:
            prob = hp.ud_grade(prob, max_nside, order_in='NESTED',
                               order_out='NESTED')
        if combined_prob is None:
            combined_prob = np.ones_like(prob)
        combined_prob *= prob

    # normalize joint posterior
    norm = combined_prob.sum()
    if norm == 0:
        raise RuntimeError('input sky localizations are disjoint')
    combined_prob /= norm

    out_kwargs = {'gps_creation_time': Time.now().gps,
                  'nest': True}
    if args.origin is not None:
        out_kwargs['origin'] = args.origin

    # average the various input event times
    input_gps = [x[2]['gps_time'] for x in input_skymaps if 'gps_time' in x[2]]
    if input_gps:
        out_kwargs['gps_time'] = np.mean(input_gps)

    # Combine instrument tags. Sort them so that the output header is
    # reproducible, because iteration order over a set of strings can differ
    # between runs. Omit the keyword when no input names any instrument, as
    # is done for gps_time above.
    out_instruments = set()
    for x in input_skymaps:
        if 'instruments' in x[2]:
            out_instruments.update(x[2]['instruments'])
    if out_instruments:
        out_kwargs['instruments'] = ','.join(sorted(out_instruments))

    # update marginal distance posterior, if available
    if dist_mu is not None:
        if ah.npix_to_nside(len(dist_mu)) < max_nside:
            dist_mu = hp.ud_grade(dist_mu, max_nside, order_in='NESTED',
                                  order_out='NESTED')
            dist_sigma = hp.ud_grade(dist_sigma, max_nside, order_in='NESTED',
                                     order_out='NESTED')
            dist_norm = hp.ud_grade(dist_norm, max_nside, order_in='NESTED',
                                    order_out='NESTED')
        distmean, diststd = parameters_to_marginal_moments(combined_prob,
                                                           dist_mu,
                                                           dist_sigma)
        out_data = (combined_prob, dist_mu, dist_sigma, dist_norm)
        out_kwargs['distmean'] = distmean
        out_kwargs['diststd'] = diststd
    else:
        out_data = combined_prob

    # save input headers in output history
    out_kwargs['HISTORY'] = []
    for i, x in enumerate(input_skymaps):
        out_kwargs['HISTORY'].append('')
        out_kwargs['HISTORY'].append(
            'Headers of HDUs 0 and 1 of input file {:d}:'.format(i))
        out_kwargs['HISTORY'].append('')
        out_kwargs['HISTORY'] += [
            '{} = {}'.format(k, v) for k, v in x[3].items()]

    write_sky_map(args.output, out_data, **out_kwargs)

ligo/skymap/tool/ligo_skymap_constellations.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
#
# Copyright (C) 2019-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
List most likely constellations for a localization.

Just for fun, and for public outreach purposes.
"""

from . import ArgumentParser, FileType


def parser():
    """Return the argument parser for the constellation-listing tool."""
    p = ArgumentParser()
    p.add_argument('input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
                   default='-', nargs='?', help='Input FITS file')
    p.add_argument('-o', '--output', metavar='OUT.dat', type=FileType('w'),
                   default='-', help='Name of output file')
    return p


def main(args=None):
    """Write a table of constellations ranked by enclosed probability.

    Each pixel of the input sky map is assigned to the constellation of its
    center coordinates. Probabilities are summed per constellation and
    written, most probable first, as a tab-separated table.
    """
    opts = parser().parse_args(args)

    # Late imports
    from ..io import fits
    import astropy_healpix as ah
    from astropy.coordinates import SkyCoord
    from astropy.table import Table
    from astropy import units as u
    import healpy as hp
    import numpy as np

    prob, meta = fits.read_sky_map(opts.input.name, nest=None)

    # Sky position of every pixel, in the map's own ordering scheme.
    pixel_count = len(prob)
    lon, lat = hp.pix2ang(ah.npix_to_nside(pixel_count),
                          np.arange(pixel_count),
                          lonlat=True, nest=meta['nest'])
    names = SkyCoord(lon * u.deg, lat * u.deg).get_constellation()

    per_pixel = Table({'prob': prob, 'constellation': names}, copy=False)
    totals = per_pixel.group_by('constellation').groups.aggregate(np.sum)
    totals.sort('prob')
    totals.reverse()
    totals.write(opts.output, format='ascii.tab')

ligo/skymap/tool/ligo_skymap_contour.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
#
# Copyright (C) 2015-2018  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
Create contours for the credible levels of an all-sky probability map.
The input is a HEALPix probability map.
The output is a GeoJSON FeatureCollection (http://geojson.org/).
"""

from . import ArgumentParser, FileType


def parser():
    """Return the command-line parser for the GeoJSON contour tool."""
    parser = ArgumentParser()
    parser.add_argument(
        '-o', '--output', metavar='FILE.geojson',
        default='-', type=FileType('w'), help='output file [default: stdout]')
    parser.add_argument(
        '--contour', metavar='PERCENT', type=float, nargs='+', required=True,
        help='plot contour enclosing this percentage of probability mass')
    parser.add_argument(
        '-i', '--interpolate',
        choices='nearest nested bilinear'.split(), default='nearest',
        help='resampling interpolation method')
    parser.add_argument(
        '-s', '--simplify', action='store_true', help='simplify contour paths')
    # The two literals previously both carried a space at the join, which
    # produced "resolution  before" in the help text.
    parser.add_argument(
        '-n', '--nside', metavar='NSIDE', type=int,
        help='optionally resample to the specified resolution '
        'before generating contours')
    parser.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    return parser


def main(args=None):
    """Write GeoJSON contours for the requested credible levels.

    The sky map is read in NESTED ordering. It is optionally resampled
    (``--nside``) and/or interpolated (``--interpolate``), then converted to
    greedy credible levels in percent and contoured at each ``--contour``
    value. The output is a GeoJSON ``FeatureCollection`` with one
    ``MultiLineString`` feature per requested level, with coordinates in
    degrees.
    """
    opts = parser().parse_args(args)

    import healpy as hp
    import json

    from ..io import fits
    from .. import postprocess

    # Read input file
    prob, _ = fits.read_sky_map(opts.input.name, nest=True)

    # Resample if requested. power=-2 keeps the values a probability per
    # pixel (scaled by pixel area) when the resolution changes.
    if opts.nside is not None and opts.interpolate in ('nearest', 'nested'):
        prob = hp.ud_grade(prob, opts.nside, order_in='NESTED', power=-2)
    elif opts.nside is not None and opts.interpolate == 'bilinear':
        prob = postprocess.smooth_ud_grade(prob, opts.nside, nest=True)
    if opts.interpolate == 'nested':
        prob = postprocess.interpolate_nested(prob, nest=True)

    # Find credible levels (in percent) with the shared postprocess helper
    # instead of re-deriving the greedy cumulative sum by hand.
    cls = 100 * postprocess.find_greedy_credible_levels(prob)

    # Generate contours
    paths = list(postprocess.contour(
        cls, opts.contour, nest=True, degrees=True, simplify=opts.simplify))

    json.dump({
        'type': 'FeatureCollection',
        'features': [
            {
                'type': 'Feature',
                'properties': {
                    'credible_level': contour
                },
                'geometry': {
                    'type': 'MultiLineString',
                    'coordinates': path
                }
            }
            for contour, path in zip(opts.contour, paths)
        ]
    }, opts.output)

ligo/skymap/tool/ligo_skymap_contour_moc.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
#
# Copyright (C) 2013-2020 Giuseppe Greco, Leo Singer, and CDS team.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
Create a contour for a credible level of an all-sky probability map. The input
is a HEALPix FITS probability map. The output is a `Multi-Order Coverage (MOC)
<http://ivoa.net/documents/MOC/>`_ FITS file.
"""

from . import ArgumentParser, FileType


def parser():
    """Return the command-line parser for the MOC contour tool."""
    parser = ArgumentParser()

    # NOTE(review): main() passes this string straight to mocpy's
    # MOC.write(); it is not evident from here that the default '-' really
    # goes to stdout rather than a file named '-'. Confirm this against the
    # mocpy version in use.
    parser.add_argument(
        '--output', metavar='FILE.fits',
        default='-', type=str, help='output file [default: stdout]')
    # Implicit literal concatenation replaces the former backslash-newline
    # continuations, which embedded long runs of spaces in the help strings.
    parser.add_argument(
        '-c', '--contour', metavar='PERCENT', type=float, required=True,
        help='MOC region enclosing this percentage of probability '
             '[range is 0-100]')
    parser.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?',
        help='Input multi-order or flattened HEALPix FITS file')

    return parser


def main(args=None):
    """Write a MOC covering the requested credible level of a sky map.

    The input is read as a multi-order sky map. Each cell's probability is
    ``PROBDENSITY`` (per steradian) times the cell area in steradians. The
    cells are passed to mocpy's ``MOC.from_valued_healpix_cells`` with a
    cumulative-probability window from 0 to ``--contour / 100``, and the
    resulting MOC is written as FITS to ``--output``.
    """
    p = parser()
    # Parse with the same parser instance that reports errors below, rather
    # than constructing a second, identical parser.
    opts = p.parse_args(args)

    # The help text documents a percentage in [0, 100]; reject anything else
    # (including NaN) before doing any work.
    if not 0 <= opts.contour <= 100:
        p.error('--contour must be in the range 0-100')

    import astropy_healpix as ah
    import astropy.units as u

    try:
        from mocpy import MOC
    except ImportError:
        p.error('This command-line tool requires mocpy >= 0.8.2. '
                'Please install it by running "pip install mocpy".')

    from ..io import read_sky_map

    # Read multi-order sky map
    skymap = read_sky_map(opts.input.name, moc=True)

    uniq = skymap['UNIQ']
    probdensity = skymap['PROBDENSITY']

    # Probability per cell = density (per steradian) x cell area (steradian).
    level, _ = ah.uniq_to_level_ipix(uniq)
    area = ah.nside_to_pixel_area(
        ah.level_to_nside(level)).to_value(u.steradian)

    prob = probdensity * area

    # Create MOC
    contour_decimal = opts.contour / 100
    moc = MOC.from_valued_healpix_cells(
        uniq, prob, cumul_from=0.0, cumul_to=contour_decimal)

    # Write MOC
    moc.write(opts.output, format='fits', overwrite=True)

ligo/skymap/tool/ligo_skymap_flatten.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Convert a HEALPix FITS file from multi-resolution UNIQ indexing to the more
common IMPLICIT indexing.
"""

from . import ArgumentParser, FileType


def parser():
    """Return the argument parser for the UNIQ-to-IMPLICIT flattening tool."""
    ap = ArgumentParser()
    ap.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        help='Input FITS file')
    ap.add_argument(
        'output', metavar='OUTPUT.fits', type=FileType('wb'),
        help='Output FITS file')
    ap.add_argument(
        '--nside', type=int, help='Output HEALPix resolution')
    return ap


def main(args=None):
    """Flatten a multi-order (NUNIQ) sky map into a single HEALPix grid.

    A warning is issued if the input's ``ORDERING`` header is not
    ``NUNIQ``. The flattened table is written in NESTED ordering to the
    output file name.
    """
    args = parser().parse_args(args)

    import logging
    import warnings
    import astropy_healpix as ah
    from astropy.io import fits
    from ..io import read_sky_map, write_sky_map
    from ..bayestar import rasterize

    log = logging.getLogger()

    # Without --nside, order=None is passed through to rasterize() as-is.
    order = None if args.nside is None else ah.nside_to_level(args.nside)

    log.info('reading FITS file %s', args.input.name)
    # A context manager closes the HDU list once the table has been read and
    # flattened; previously it was left open for the life of the process.
    with fits.open(args.input) as hdus:
        ordering = hdus[1].header['ORDERING']
        expected_ordering = 'NUNIQ'
        if ordering != expected_ordering:
            msg = 'Expected the FITS file {} to have ordering {}, but it is {}'
            warnings.warn(
                msg.format(args.input.name, expected_ordering, ordering))
        log.debug('converting original FITS file to Astropy table')
        table = read_sky_map(hdus, moc=True)
        log.debug('flattening HEALPix tree')
        table = rasterize(table, order=order)
    log.info('writing FITS file %s', args.output.name)
    write_sky_map(args.output.name, table, nest=True)
    log.debug('done')

ligo/skymap/tool/ligo_skymap_from_samples.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
#
# Copyright (C) 2011-2020  Will M. Farr <will.farr@ligo.org>
#                          Leo P. Singer <leo.singer@ligo.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Generate a FITS sky map file from posterior samples using clustering and
kernel density estimation.

The output consists of two files:

*  ``skypost.obj``, a :mod:`pickle` representation of the kernel density
   estimator
*  ``skymap.fits.gz``, a 3D localization in HEALPix/FITS format
"""

from argparse import SUPPRESS

from . import ArgumentParser, DirType, EnableAction, FileType, random_parser


def parser():
    """Return the command-line parser for building a sky map from samples."""
    cli = ArgumentParser(parents=[random_parser])
    add = cli.add_argument

    add('samples', metavar='SAMPLES.hdf5', type=FileType('rb'),
        help='posterior samples file')
    # Kept so that the legacy ``--samples FILE`` spelling still parses. The
    # flag takes no value and is stored in an unused dest; FILE is then
    # consumed as the positional ``samples`` argument.
    add('--samples', dest='_ignored', action='store_false', help=SUPPRESS)
    add('--outdir', '-o', type=DirType(create=True), default='.',
        help='output directory')
    add('--fitsoutname', metavar='SKYMAP.fits[.gz]', default='skymap.fits',
        help='filename for the FITS file')
    add('--loadpost', metavar='SKYPOST.obj', type=FileType('rb'),
        help='filename for pickled posterior state')
    add('--maxpts', type=int,
        help=('maximum number of posterior points to use; if omitted or '
              'greater than or equal to the number of posterior samples, '
              'then use all samples'))
    add('--trials', type=int, default=5,
        help='number of trials at each clustering number')
    add('--enable-distance-map', action=EnableAction,
        help='generate HEALPix map of distance estimates')
    add('--enable-multiresolution', action=EnableAction, default=True,
        help='generate a multiresolution HEALPix map')
    add('--top-nside', type=int, default=16,
        help=('choose a start nside before HEALPix refinement steps '
              '(must be a valid nside)'))
    # A bare ``-j`` with no value stores None (const); omitted means 1.
    add('-j', '--jobs', type=int, nargs='?', const=None, default=1,
        help='Number of threads')
    add('--instruments', metavar='H1|L1|V1|...', nargs='+',
        help='instruments to store in FITS header')
    add('--objid', help='event ID to store in FITS header')
    add('--path', type=str, default=None,
        help="The path of the dataset within the HDF5 file")
    add('--tablename', type=str, default='posterior_samples',
        help=('The name of the table to search for recursively within the '
              'HDF5 file. By default, search for posterior_samples'))
    return cli


def main(args=None):
    """Build a sky map from posterior samples and write it in FITS format.

    Samples are read from an HDF5 file, falling back to an ASCII table, and
    optionally subsampled to ``--maxpts`` points. Unless a pickled KDE is
    supplied with ``--loadpost``, a clustered sky KDE is built from the
    samples and pickled to ``skypost.obj`` in the output directory. The KDE
    is evaluated as a HEALPix map and written to ``--fitsoutname`` in the
    output directory, with provenance metadata in its header.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    import sys

    # Remember the arguments actually parsed, so that the HISTORY record is
    # correct when main() is called with an explicit argument list rather
    # than from the command line.
    argv = sys.argv[1:] if args is None else list(args)

    _parser = parser()
    args = _parser.parse_args(args)

    # Late imports
    from .. import io
    from ..io.hdf5 import _remap_colnames
    from ..bayestar import rasterize
    from .. import version
    from astropy.table import Table
    from astropy.time import Time
    import numpy as np
    import os
    import pickle
    from ..kde import Clustered2Plus1DSkyKDE, Clustered2DSkyKDE
    import logging

    log = logging.getLogger()

    log.info('reading samples')
    try:
        data = io.read_samples(args.samples.name, path=args.path,
                               tablename=args.tablename)
    except IOError:
        # FIXME: remove this code path once we support only HDF5
        data = Table.read(args.samples, format='ascii')
        _remap_colnames(data)

    if args.maxpts is not None and args.maxpts < len(data):
        log.info('taking random subsample of chain')
        data = data[np.random.choice(len(data), args.maxpts, replace=False)]

    # Distance column, if any, under either of its accepted names.
    try:
        dist = data['dist']
    except KeyError:
        try:
            dist = data['distance']
        except KeyError:
            dist = None

    if args.loadpost is None:
        if dist is None:
            if args.enable_distance_map:
                _parser.error("The posterior samples file '{}' does not have "
                              "a distance column named 'dist' or 'distance'. "
                              "Cannot generate distance map. If you do not "
                              "intend to generate a distance map, then add "
                              "the '--disable-distance-map' command line "
                              "argument.".format(args.samples.name))
            pts = np.column_stack((data['ra'], data['dec']))
        else:
            pts = np.column_stack((data['ra'], data['dec'], dist))
        if args.enable_distance_map:
            cls = Clustered2Plus1DSkyKDE
        else:
            cls = Clustered2DSkyKDE
        skypost = cls(pts, trials=args.trials, jobs=args.jobs)

        log.info('pickling')
        with open(os.path.join(args.outdir, 'skypost.obj'), 'wb') as out:
            pickle.dump(skypost, out)
    else:
        # SECURITY NOTE: pickle.load can execute arbitrary code. Only pass
        # --loadpost files produced by this tool or from a trusted source.
        skypost = pickle.load(args.loadpost)
        skypost.jobs = args.jobs

    log.info('making skymap')
    hpmap = skypost.as_healpix(top_nside=args.top_nside)
    if not args.enable_multiresolution:
        hpmap = rasterize(hpmap)
    hpmap.meta.update(io.fits.metadata_for_version_module(version))
    hpmap.meta['creator'] = _parser.prog
    hpmap.meta['origin'] = 'LIGO/Virgo'
    hpmap.meta['gps_creation_time'] = Time.now().gps
    hpmap.meta['history'] = [
        '', 'Generated by running the following script:',
        ' '.join([_parser.prog] + argv)]
    if args.objid is not None:
        hpmap.meta['objid'] = args.objid
    if args.instruments:
        hpmap.meta['instruments'] = args.instruments
    if args.enable_distance_map:
        if dist is None:
            # Only reachable with --loadpost: the missing-column check above
            # is skipped when a pickled KDE is loaded instead of built.
            log.warning('samples have no distance column; not recording '
                        'distmean/diststd in the FITS header')
        else:
            hpmap.meta['distmean'] = np.mean(dist)
            hpmap.meta['diststd'] = np.std(dist)

    keys = ['time', 'time_mean', 'time_maxl']
    for key in keys:
        try:
            time = data[key]
        except KeyError:
            continue
        else:
            hpmap.meta['gps_time'] = time.mean()
            break
    else:
        log.warning(
            'Cannot determine the event time from any of the columns %r', keys)

    io.write_sky_map(os.path.join(args.outdir, args.fitsoutname),
                     hpmap, nest=True)

ligo/skymap/tool/ligo_skymap_plot.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
#
# Copyright (C) 2011-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""
Plot an all-sky map on a Mollweide projection.
By default, plot in celestial coordinates (RA, Dec).

To plot in geographic coordinates (longitude, latitude) with major
coastlines overlaid, provide the ``--geo`` flag.

Public-domain cartographic data is courtesy of `Natural Earth
<http://www.naturalearthdata.com>`_ and processed with `MapShaper
<http://www.mapshaper.org>`_.
"""

from . import ArgumentParser, FileType, SQLiteType
from .matplotlib import figure_parser


def parser():
    """Return the command-line parser for the all-sky plotting tool."""
    p = ArgumentParser(parents=[figure_parser])
    add = p.add_argument

    add('--annotate', action='store_true', default=False,
        help='annotate plot with information about the event')
    add('--contour', metavar='PERCENT', type=float, nargs='+',
        help=('plot contour enclosing this percentage of probability mass '
              '[may be specified multiple times, default: none]'))
    add('--colorbar', action='store_true', default=False,
        help='Show colorbar')
    # Each use appends one (RA, Dec) pair, both in degrees.
    add('--radec', metavar='deg', type=float, nargs=2, action='append',
        default=[],
        help='right ascension (deg) and declination (deg) to mark')
    add('--inj-database', metavar='FILE.sqlite', type=SQLiteType('r'),
        help='read injection positions from database')
    add('--geo', action='store_true',
        help=('Use a terrestrial reference frame with coordinates (lon, lat) '
              'instead of the celestial frame with coordinates (RA, dec) '
              'and draw continents on the map'))
    add('--projection', type=str, default='mollweide',
        choices=['mollweide', 'aitoff', 'globe', 'zoom'],
        help='Projection style')
    add('--projection-center', metavar='CENTER',
        help='Specify the center for globe and zoom projections, e.g. 14h 10d')
    add('--zoom-radius', metavar='RADIUS',
        help='Specify the radius for zoom projections, e.g. 4deg')
    add('input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    return p


def main(args=None):
    """Plot a HEALPix probability sky map with optional overlays.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments passed to ``parse_args`` (``None`` lets
        argparse read ``sys.argv``).
    """
    opts = parser().parse_args(args)

    # Late imports

    import os
    import json
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    from ..io import fits
    from .. import plot
    from .. import postprocess
    import astropy_healpix as ah
    from astropy.coordinates import SkyCoord
    from astropy.time import Time
    from astropy import units as u

    # nest=None: presumably keeps the file's native pixel ordering; the
    # ordering actually used is read back from metadata['nest'] below.
    skymap, metadata = fits.read_sky_map(opts.input.name, nest=None)
    nside = ah.npix_to_nside(len(skymap))

    # Convert sky map from probability to probability per square degree.
    deg2perpix = ah.nside_to_pixel_area(nside).to_value(u.deg**2)
    probperdeg2 = skymap / deg2perpix

    # Build the projection name, e.g. 'astro mollweide' or 'geo globe'.
    axes_args = {}
    if opts.geo:
        axes_args['projection'] = 'geo'
        # Terrestrial frame needs the event time to rotate the sky onto Earth.
        obstime = Time(metadata['gps_time'], format='gps').utc.isot
        axes_args['obstime'] = obstime
    else:
        axes_args['projection'] = 'astro'
    axes_args['projection'] += ' ' + opts.projection
    if opts.projection_center is not None:
        axes_args['center'] = SkyCoord(opts.projection_center)
    if opts.zoom_radius is not None:
        axes_args['radius'] = opts.zoom_radius
    ax = plt.axes(**axes_args)
    ax.grid()

    # Plot sky map.
    vmax = probperdeg2.max()
    img = ax.imshow_hpx(
        (probperdeg2, 'ICRS'), nested=metadata['nest'], vmin=0., vmax=vmax)

    # Add colorbar.
    if opts.colorbar:
        cb = plot.colorbar(img)
        cb.set_label(r'prob. per deg$^2$')

    # Add contours.
    if opts.contour:
        # Greedy credible level of each pixel, in percent, so that the
        # requested PERCENT values can be used directly as contour levels.
        cls = 100 * postprocess.find_greedy_credible_levels(skymap)
        cs = ax.contour_hpx(
            (cls, 'ICRS'), nested=metadata['nest'],
            colors='k', linewidths=0.5, levels=opts.contour)
        # Under usetex the percent sign must be escaped for LaTeX.
        fmt = r'%g\%%' if rcParams['text.usetex'] else '%g%%'
        plt.clabel(cs, fmt=fmt, fontsize=6, inline=True)

    # Add continents.
    if opts.geo:
        geojson_filename = os.path.join(
            os.path.dirname(plot.__file__), 'ne_simplified_coastline.json')
        with open(geojson_filename, 'r') as geojson_file:
            geoms = json.load(geojson_file)['geometries']
        # Unzip each geometry's coordinate list into (xs, ys); the flattened
        # list alternates xs, ys, xs, ys, ... as plt.plot(x1, y1, x2, y2)
        # expects.
        verts = [coord for geom in geoms
                 for coord in zip(*geom['coordinates'])]
        plt.plot(*verts, color='0.5', linewidth=0.5,
                 transform=ax.get_transform('world'))

    # NOTE: this aliases opts.radec, so the append below also mutates it.
    radecs = opts.radec
    if opts.inj_database:
        # Look up the injection coincident with this event's objid.
        query = '''SELECT DISTINCT longitude, latitude FROM sim_inspiral AS si
                   INNER JOIN coinc_event_map AS cm1
                   ON (si.simulation_id = cm1.event_id)
                   INNER JOIN coinc_event_map AS cm2
                   ON (cm1.coinc_event_id = cm2.coinc_event_id)
                   WHERE cm2.event_id = ?
                   AND cm1.table_name = 'sim_inspiral'
                   AND cm2.table_name = 'coinc_event'
                   '''
        # The single-element unpacking raises ValueError unless exactly one
        # row matches. Columns are in radians (converted to degrees below).
        (ra, dec), = opts.inj_database.execute(
            query, (metadata['objid'],)).fetchall()
        radecs.append(np.rad2deg([ra, dec]).tolist())

    # Add markers (e.g., for injections or external triggers).
    for ra, dec in radecs:
        ax.plot_coord(
            SkyCoord(ra, dec, unit='deg'), '*',
            markerfacecolor='white', markeredgecolor='black', markersize=10)

    # Add a white outline to all text to make it stand out from the background.
    plot.outline_text(ax)

    if opts.annotate:
        text = []
        try:
            objid = metadata['objid']
        except KeyError:
            pass
        else:
            text.append('event ID: {}'.format(objid))
        if opts.contour:
            # Area of each credible region = number of pixels whose credible
            # level is below the contour, times the area of one pixel.
            pp = np.round(opts.contour).astype(int)
            ii = np.round(np.searchsorted(np.sort(cls), opts.contour) *
                          deg2perpix).astype(int)
            for i, p in zip(ii, pp):
                # FIXME: use Unicode symbol instead of TeX '$^2$'
                # because of broken fonts on Scientific Linux 7.
                text.append('{:d}% area: {:,d} deg²'.format(p, i))
        ax.text(1, 1, '\n'.join(text), transform=ax.transAxes, ha='right')

    # Show or save output.
    opts.output()

ligo/skymap/tool/ligo_skymap_plot_airmass.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Make an airmass chart for a LIGO/Virgo probability sky map."""

import numpy as np

from . import ArgumentParser, FileType, HelpChoicesAction
from .matplotlib import figure_parser


def parser():
    """Return the argument parser for the airmass chart tool.

    The observatory is chosen either by name with ``--site`` (one of
    astropy's known site names) or by explicit ``--site-longitude`` /
    ``--site-latitude`` (plus optional height and time zone); ``main``
    checks that the two styles are not mixed.
    """
    from astropy.coordinates import EarthLocation
    known_sites = EarthLocation.get_site_names()

    argparser = ArgumentParser(parents=[figure_parser])
    add = argparser.add_argument

    add('-v', '--verbose', action='store_true',
        help='Print airmass table to stdout')
    add('input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    add('--time', help='UTC time')

    # Named observatory.
    add('--site', metavar='SITE', choices=known_sites,
        help='Observatory site')
    add('--help-site', action=HelpChoicesAction, choices=known_sites)

    # Custom observatory.
    add('--site-longitude', metavar='DEG', type=float,
        help='Observatory longitude on the WGS84 ellipsoid. '
        'Mutually exclusive with --site.')
    add('--site-latitude', metavar='DEG', type=float,
        help='Observatory latitude on the WGS84 ellipsoid. '
        'Mutually exclusive with --site.')
    add('--site-height', metavar='METERS', type=float,
        help='Observatory height from the WGS84 ellipsoid. '
        'Mutually exclusive with --site.')
    add('--site-timezone',
        help='Observatory time zone, e.g. "US/Pacific". '
        'Mutually exclusive with --site.')
    return argparser


def condition_secz(x):
    """Map non-positive sec(z) values (source below the horizon) to +inf.

    Positive values, and NaNs, pass through unchanged.
    """
    below_horizon = x <= 0
    return np.where(below_horizon, np.inf, x)


def clip_verylarge(x, max=1e300):
    """Bound ``x`` to the closed interval ``[-max, max]``.

    Used to turn infinite airmass values into large finite ones, because
    ``fill_between`` cannot handle infinities. NaNs propagate unchanged.
    """
    floored = np.maximum(x, -max)
    return np.minimum(floored, max)


def main(args=None):
    """Make an airmass chart for a sky map as seen from one observatory.

    Plots, around local midnight of ``--time`` (or now), the
    probability-weighted airmass percentiles of the sky map as filled bands,
    shades twilight and night, and optionally adds a local-time axis and
    prints the airmass table.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments (``None`` lets argparse read ``sys.argv``).
    """
    p = parser()
    opts = p.parse_args(args)

    # Late imports
    import operator
    import sys

    from astroplan import Observer
    from astroplan.plots import plot_airmass
    from astropy.coordinates import EarthLocation, SkyCoord
    from astropy.table import Table
    from astropy.time import Time
    from astropy import units as u
    from matplotlib import dates
    from matplotlib.cm import ScalarMappable
    from matplotlib.colors import Normalize
    from matplotlib.patches import Patch
    from matplotlib import pyplot as plt
    from tqdm import tqdm
    import pytz

    from ..io import fits
    from .. import moc
    from .. import plot  # noqa
    from ..extern.numpy.quantile import percentile

    if opts.site is None:
        if opts.site_longitude is None or opts.site_latitude is None:
            p.error('must specify either --site or both '
                    '--site-longitude and --site-latitude')
        location = EarthLocation(
            lon=opts.site_longitude * u.deg,
            lat=opts.site_latitude * u.deg,
            height=(opts.site_height or 0) * u.m)
        if opts.site_timezone is not None:
            location.info.meta = {'timezone': opts.site_timezone}
        observer = Observer(location)
    else:
        if not((opts.site_longitude is None) and
               (opts.site_latitude is None) and
               (opts.site_height is None) and
               (opts.site_timezone is None)):
            p.error('argument --site not allowed with arguments '
                    '--site-longitude, --site-latitude, '
                    '--site-height, or --site-timezone')
        observer = Observer.at_site(opts.site)

    m = fits.read_sky_map(opts.input.name, moc=True)

    # Make an empty airmass chart.
    t0 = Time(opts.time) if opts.time is not None else Time.now()
    t0 = observer.midnight(t0)
    ax = plot_airmass([], observer, t0, altitude_yaxis=True)

    # Remove the fake source and determine times that were used for the plot.
    # Remove each line through Artist.remove() instead of ``del ax.lines[:]``:
    # since Matplotlib 3.5 Axes.lines is a read-only ArtistList and deleting
    # from it is deprecated (an error from 3.7). Iterate over a copy because
    # each removal mutates ax.lines.
    for line in list(ax.lines):
        line.remove()
    times = Time(np.linspace(*ax.get_xlim()), format='plot_date')

    theta, phi = moc.uniq2ang(m['UNIQ'])
    coords = SkyCoord(phi, 0.5 * np.pi - theta, unit='rad')
    # Probability contained in each multi-order pixel.
    prob = moc.uniq2pixarea(m['UNIQ']) * m['PROBDENSITY']

    # Central credible intervals 90%, 80%, ..., 10%: the first nlevels rows
    # of `airmass` are lower bounds and the last nlevels rows upper bounds.
    levels = np.arange(90, 0, -10)
    nlevels = len(levels)
    percentiles = np.concatenate((50 - 0.5 * levels, 50 + 0.5 * levels))

    airmass = np.column_stack([
        percentile(
            condition_secz(coords.transform_to(observer.altaz(t)).secz),
            percentiles,
            weights=prob)
        for t in tqdm(times)])

    cmap = ScalarMappable(Normalize(0, 100), plt.get_cmap())
    for level, lo, hi in zip(levels, airmass[:nlevels], airmass[nlevels:]):
        ax.fill_between(
            times.plot_date,
            clip_verylarge(lo),  # Clip infinities to large but finite values
            clip_verylarge(hi),  # because fill_between cannot handle inf
            color=cmap.to_rgba(level), zorder=2)

    ax.legend(
        [Patch(facecolor=cmap.to_rgba(level)) for level in levels],
        ['{}%'.format(level) for level in levels])
    # ax.set_title('{} from {}'.format(m.meta['objid'], observer.name))

    # Adapted from astroplan
    start = times[0]
    twilights = [
        (times[0].datetime, 0.0),
        (observer.sun_set_time(
            Time(start), which='next').datetime, 0.0),
        (observer.twilight_evening_civil(
            Time(start), which='next').datetime, 0.1),
        (observer.twilight_evening_nautical(
            Time(start), which='next').datetime, 0.2),
        (observer.twilight_evening_astronomical(
            Time(start), which='next').datetime, 0.3),
        (observer.twilight_morning_astronomical(
            Time(start), which='next').datetime, 0.4),
        (observer.twilight_morning_nautical(
            Time(start), which='next').datetime, 0.3),
        (observer.twilight_morning_civil(
            Time(start), which='next').datetime, 0.2),
        (observer.sun_rise_time(
            Time(start), which='next').datetime, 0.1),
        (times[-1].datetime, 0.0),
    ]

    # Shade each interval between consecutive events: grey darkness grows
    # with the tabulated alpha, and a white veil fades non-night intervals.
    twilights.sort(key=operator.itemgetter(0))
    for i, twi in enumerate(twilights[1:], 1):
        if twi[1] != 0:
            ax.axvspan(twilights[i - 1][0], twilights[i][0],
                       ymin=0, ymax=1, color='grey', alpha=twi[1], linewidth=0)
        if twi[1] != 0.4:
            ax.axvspan(twilights[i - 1][0], twilights[i][0],
                       ymin=0, ymax=1, color='white', alpha=0.8 - 2 * twi[1],
                       zorder=3, linewidth=0)

    # Add local time axis
    timezone = (observer.location.info.meta or {}).get('timezone')
    if timezone:
        tzinfo = pytz.timezone(timezone)
        ax2 = ax.twiny()
        ax2.set_xlim(ax.get_xlim())
        ax2.set_xticks(ax.get_xticks())
        ax2.xaxis.set_major_formatter(dates.DateFormatter('%H:%M', tz=tzinfo))
        plt.setp(ax2.get_xticklabels(), rotation=-30, ha='right')
        ax2.set_xlabel("Time from {} [{}]".format(
            min(times).to_datetime(tzinfo).date(),
            timezone))

    if opts.verbose:
        # Write airmass table to stdout.
        times.format = 'isot'
        table = Table(masked=True)
        table['time'] = times
        table['sun_alt'] = np.ma.masked_greater_equal(
            observer.sun_altaz(times).alt, 0)
        table['sun_alt'].format = lambda x: '{}'.format(int(np.round(x)))
        # Use `pct`, not `p`, so the parser object `p` is not shadowed.
        for pct, data in sorted(zip(percentiles, airmass)):
            column = str(pct)
            table[column] = np.ma.masked_invalid(data)
            table[column].format = lambda x: '{:.01f}'.format(np.around(x, 1))
        table.write(sys.stdout, format='ascii.fixed_width')

    # Show or save output.
    opts.output()

ligo/skymap/tool/ligo_skymap_plot_observability.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
#
# Copyright (C) 2019-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Make an observability chart for a LIGO/Virgo probability sky map."""

import numpy as np

from . import ArgumentParser, FileType, HelpChoicesAction
from .matplotlib import figure_parser


def parser():
    """Return the argument parser for the observability chart tool.

    Observatories may be given by name (``--site``, any number of astropy
    site names) and/or as custom sites described by parallel lists of
    ``--site-name``, ``--site-longitude``, ``--site-latitude`` and
    ``--site-height`` values.
    """
    from astropy.coordinates import EarthLocation
    known_sites = EarthLocation.get_site_names()

    argparser = ArgumentParser(parents=[figure_parser])
    add = argparser.add_argument

    add('-v', '--verbose', action='store_true',
        help='Print airmass table to stdout')
    add('input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    add('--time', help='UTC time')
    add('--max-airmass', default=2.5, type=float, help='Maximum airmass')
    add('--twilight', default='astronomical',
        choices=('astronomical', 'nautical', 'civil'),
        help='Twilight definition: astronomical (-18 degrees), '
        'nautical (-12 degrees), or civil (-6 degrees)')
    add('--site', nargs='*', default=[], metavar='SITE', choices=known_sites,
        help='Observatory site')
    add('--help-site', action=HelpChoicesAction, choices=known_sites)
    add('--site-name', nargs='*', default=[], help='Observatory name.')

    # Numeric coordinates of custom sites, one value per site.
    numeric_site_options = (
        ('--site-longitude', 'DEG',
         'Observatory longitude on the WGS84 ellipsoid.'),
        ('--site-latitude', 'DEG',
         'Observatory latitude on the WGS84 ellipsoid.'),
        ('--site-height', 'METERS',
         'Observatory height from the WGS84 ellipsoid.'),
    )
    for flag, metavar, help_text in numeric_site_options:
        add(flag, nargs='*', default=[], metavar=metavar, type=float,
            help=help_text)
    return argparser


def condition_secz(x):
    """Replace sec(z) values <= 0 (below the horizon) with infinite airmass.

    All other values, including NaN, are returned unchanged.
    """
    is_below = np.less_equal(x, 0)
    return np.where(is_below, np.inf, x)


def main(args=None):
    """Make an observability chart for a sky map from one or more sites.

    For each observer, draws a horizontal strip over the 24 hours starting
    at ``--time`` (or now), colored by the percentage of sky-map probability
    that satisfies the night-time (twilight) and maximum-airmass
    constraints at each time.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments (``None`` lets argparse read ``sys.argv``).

    NOTE(review): the ``--verbose`` option defined by ``parser`` is never
    read in this function.
    """
    p = parser()
    opts = p.parse_args(args)

    # Late imports
    from astroplan import (
        AirmassConstraint, AtNightConstraint, Observer, is_event_observable)
    from astropy.coordinates import EarthLocation, SkyCoord
    from astropy.time import Time
    from astropy import units as u
    from matplotlib import dates
    from matplotlib import pyplot as plt
    from tqdm import tqdm

    from ..io import fits
    from .. import moc
    from .. import plot  # noqa

    # The --site-* options are parallel lists describing custom sites, so
    # they must all have the same length.
    names = ('name', 'longitude', 'latitude', 'height')
    length0, *lengths = (
        len(getattr(opts, 'site_{}'.format(name))) for name in names)
    if not all(length0 == length for length in lengths):
        p.error('these options require equal numbers of arguments: {}'.format(
            ', '.join(
                '--site-{}'.format(name) for name in names)))

    observers = [Observer.at_site(site) for site in opts.site]
    for name, lon, lat, height in zip(
            opts.site_name, opts.site_longitude, opts.site_latitude,
            opts.site_height):
        location = EarthLocation(
            lon=lon * u.deg,
            lat=lat * u.deg,
            height=(height or 0) * u.m)
        observers.append(Observer(location, name=name))
    # Row i is drawn at y = i, so reversing puts the first site at the top.
    observers = list(reversed(observers))

    m = fits.read_sky_map(opts.input.name, moc=True)

    # 50 evenly spaced times (np.linspace default) spanning one day.
    t0 = Time(opts.time) if opts.time is not None else Time.now()
    times = t0 + np.linspace(0, 1) * u.day

    theta, phi = moc.uniq2ang(m['UNIQ'])
    coords = SkyCoord(phi, 0.5 * np.pi - theta, unit='rad')
    # Probability contained in each multi-order pixel.
    prob = np.asarray(moc.uniq2pixarea(m['UNIQ']) * m['PROBDENSITY'])

    # e.g. AtNightConstraint.twilight_astronomical() for the default.
    constraints = [
        getattr(AtNightConstraint, 'twilight_{}'.format(opts.twilight))(),
        AirmassConstraint(opts.max_airmass)]

    # Scale figure height with the number of observer rows.
    fig = plt.figure()
    width, height = fig.get_size_inches()
    fig.set_size_inches(width, (len(observers) + 1) / 16 * width)
    ax = plt.axes()
    locator = dates.AutoDateLocator()
    formatter = dates.DateFormatter('%H:%M')
    ax.set_xlim([times[0].plot_date, times[-1].plot_date])
    ax.xaxis.set_major_formatter(formatter)
    ax.xaxis.set_major_locator(locator)
    ax.set_xlabel("Time from {0} [UTC]".format(min(times).datetime.date()))
    plt.setp(ax.get_xticklabels(), rotation=30, ha='right')
    ax.set_yticks(np.arange(len(observers)))
    ax.set_yticklabels([observer.name for observer in observers])
    ax.yaxis.set_tick_params(left=False)
    ax.grid(axis='x')
    ax.spines['bottom'].set_visible(False)
    ax.spines['top'].set_visible(False)

    for i, observer in enumerate(tqdm(observers)):
        # Percentage of probability observable at each time; assumes
        # is_event_observable returns a (n_coords, n_times) boolean array
        # — TODO confirm against astroplan.
        observable = 100 * np.dot(prob, is_event_observable(
            constraints, observer, coords, times))
        # Duplicate the 1-D series into two rows so contourf fills a strip
        # of height 0.8 centered on row i.
        ax.contourf(
            times.plot_date, [i - 0.4, i + 0.4], np.tile(observable, (2, 1)),
            levels=np.arange(10, 110, 10), cmap=plt.get_cmap().reversed())

    plt.tight_layout()

    # Show or save output.
    opts.output()

ligo/skymap/tool/ligo_skymap_plot_pp_samples.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
#
# Copyright (C) 2017-2019  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Create a P-P plot to compare a posterior sample chain with a sky map."""

import re

import numpy as np

from . import ArgumentParser, FileType
from .matplotlib import figure_parser


def fmt(x, sigfigs, force_scientific=False):
    """Round and format a number, in LaTeX scientific notation if needed.

    ``x`` is rounded to ``sigfigs`` digits after the leading digit (i.e. to
    the same precision as the mantissa of ``'{:.<sigfigs>e}'``). If that
    rounding leaves no fractional digits and ``force_scientific`` is false,
    the value is returned as a plain integer string; otherwise it is
    returned as ``'$m \\times 10^{e}$'``.

    Parameters
    ----------
    x : float
        Value to format. Zero, negative and non-finite values are accepted.
    sigfigs : int
        Number of digits after the leading digit.
    force_scientific : bool, optional
        Always use scientific notation.

    Returns
    -------
    str
        Formatted value; ``'nan'``, ``'inf'`` or ``'-inf'`` for non-finite
        input.
    """
    # log10 is undefined for non-finite values; int(floor(...)) would raise.
    if not np.isfinite(x):
        return '{:g}'.format(float(x))
    magnitude = abs(x)
    # Decimal exponent of the leading digit; zero has no leading digit, so
    # treat it like a number of order unity instead of taking log10(0).
    exponent = int(np.floor(np.log10(magnitude))) if magnitude > 0 else 0
    places = sigfigs - exponent
    x_rounded = np.around(x, places)
    if places <= 0 and not force_scientific:
        return '{:d}'.format(int(x_rounded))
    else:
        s = ('{:.' + str(sigfigs) + 'e}').format(x_rounded)
        return re.sub(r'^(.*)e\+?(-?)0*(\d+)$', r'$\1 \\times 10^{\2\3}$', s)


def parser():
    """Return the argument parser for the P-P plot tool.

    Takes a FITS sky map and an HDF5 posterior sample file, plus optional
    sample down-selection and credible-contour percentages to annotate.
    """
    argparser = ArgumentParser(parents=[figure_parser])
    add = argparser.add_argument

    # Positional inputs.
    add('skymap', metavar='SKYMAP.fits[.gz]', type=FileType('rb'),
        help='FITS sky map file')
    add('samples', metavar='SAMPLES.hdf5', type=FileType('rb'),
        help='HDF5 posterior sample chain file')

    # Options.
    add('-m', '--max-points', type=int,
        help='Maximum number of posterior sample points '
        '[default: all of them]')
    add('-p', '--contour', default=[], nargs='+', type=float,
        metavar='PERCENT',
        help='Report the areas and volumes within the smallest contours '
        'containing this much probability.')
    return argparser


def main(args=None):
    """Draw a P-P plot comparing posterior samples with a sky map.

    Plots the distribution of searched probability of the posterior samples
    under the sky map (sky position, plus distance and volume when the map
    has a DISTMU column), and annotates each requested contour with its
    credible area/volume and with the area/volume at that sample quantile.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments (``None`` lets argparse read ``sys.argv``).
    """
    args = parser().parse_args(args)

    # Late imports.
    from astropy.table import Table
    from matplotlib import pyplot as plt
    from scipy.interpolate import interp1d
    from .. import io
    from .. import plot as _  # noqa: F401
    from ..postprocess import find_injection_moc

    # Read input.
    skymap = io.read_sky_map(args.skymap.name, moc=True)
    chain = io.read_samples(args.samples.name)

    # If required, downselect to a smaller number of posterior samples.
    # NOTE: uses the global NumPy RNG, so the subset is not reproducible
    # unless the caller seeds it.
    if args.max_points is not None:
        chain = Table(np.random.permutation(chain)[:args.max_points],
                      copy=False)

    # Calculate P-P plot.
    # Contours are given in percent on the command line but as fractions
    # to find_injection_moc.
    contours = np.asarray(args.contour)
    result = find_injection_moc(skymap,
                                chain['ra'], chain['dec'], chain['dist'],
                                contours=1e-2 * contours)

    # Make Matplotlib figure.
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='pp_plot')
    ax.add_diagonal()
    ax.add_series(result.searched_prob, label='R.A., Dec.')
    # Map a cumulative sample fraction to the searched area at that quantile;
    # out-of-range inputs give NaN (bounds_error=False).
    searched_area_func = interp1d(np.linspace(0, 1, len(chain)),
                                  np.sort(result.searched_area),
                                  bounds_error=False)
    if 'DISTMU' in skymap.colnames:
        ax.add_series(result.searched_prob_dist, label='Distance')
        ax.add_series(result.searched_prob_vol, label='Volume')
        searched_vol_func = interp1d(np.linspace(0, 1, len(chain)),
                                     np.sort(result.searched_vol),
                                     bounds_error=False)
    for p, area, vol in zip(
            args.contour, result.contour_areas, result.contour_vols):
        # Annotation below the diagonal: the sky map's own credible
        # area/volume for this contour.
        text = '{:g}%\n{} deg$^2$'.format(p, fmt(area, 2))
        if 'DISTMU' in skymap.colnames:
            text += '\n{} Mpc$^3$'.format(fmt(vol, 2, force_scientific=True))
        ax.annotate(
            text, (1e-2 * p, 1e-2 * p), (0, -150),
            xycoords='data', textcoords='offset points',
            horizontalalignment='right', backgroundcolor='white',
            arrowprops=dict(connectionstyle='bar,angle=0,fraction=0',
                            arrowstyle='-|>', linewidth=2, color='black'))
        # Annotation to the left: searched area/volume at the same quantile
        # of the posterior samples.
        area = searched_area_func(1e-2 * p)
        text = '{:g}%\n{} deg$^2$'.format(p, fmt(area, 2))
        if 'DISTMU' in skymap.colnames:
            vol = searched_vol_func(1e-2 * p)
            text += '\n{} Mpc$^3$'.format(fmt(vol, 2, force_scientific=True))
        ax.annotate(
            text, (1e-2 * p, 1e-2 * p), (-75, 0),
            xycoords='data', textcoords='offset points',
            horizontalalignment='right', verticalalignment='center',
            backgroundcolor='white',
            arrowprops=dict(connectionstyle='bar,angle=0,fraction=0',
                            arrowstyle='-|>', linewidth=2, color='black'))
    ax.set_xlabel('searched probability')
    ax.set_ylabel('cumulative fraction of posterior samples')
    ax.set_title(args.skymap.name)
    ax.legend()
    ax.grid()

    # Show or save output.
    args.output()

ligo/skymap/tool/ligo_skymap_plot_stats.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Create summary plots for sky maps of found injections, optionally binned
cumulatively by false alarm rate or SNR.
"""

from distutils.dir_util import mkpath
import os

from . import ArgumentParser, FileType
from .matplotlib import MatplotlibFigureType


def parser():
    """Return the argument parser for the summary-statistics plotting tool.

    Inputs are one or more tables written by ``ligo-skymap-stats``; plots
    are written to the ``--output`` directory in the ``--format`` chosen.
    """
    argparser = ArgumentParser()
    add = argparser.add_argument

    # Binning and normalization of the plots.
    add('--cumulative', action='store_true')
    add('--normed', action='store_true')
    add('--group-by', choices=('far', 'snr'), metavar='far|snr',
        help='Group plots by false alarm rate (FAR) or '
        'signal to noise ratio (SNR)')
    add('--pp-confidence-interval', type=float, metavar='PCT',
        default=95, help='If all inputs files have the same number of '
        'samples, overlay binomial confidence bands for this percentage on '
        'the P--P plot')

    # Input and output.
    add('--format', default='pdf', help='Matplotlib format')
    add('input', type=FileType('rb'), nargs='+',
        help='Name of input file generated by ligo-skymap-stats')
    add('--output', '-o', default='.', help='output directory')
    return argparser


def main(args=None):
    opts = parser().parse_args(args)

    # Imports.
    from astropy.table import Table
    import matplotlib
    matplotlib.use('agg')
    from matplotlib import pyplot as plt
    from matplotlib import rcParams
    import numpy as np
    from tqdm import tqdm
    from .. import plot  # noqa

    # Read in all of the datasets listed as positional command line arguments.
    datasets = [Table.read(file, format='ascii') for file in opts.input]

    # Determine plot colors and labels
    filenames = [file.name for file in opts.input]
    labels = [os.path.splitext(os.path.basename(f))[0] for f in filenames]
    if rcParams['text.usetex']:
        labels = [r'\verb/' + label + '/' for label in labels]
    rcParams['savefig.format'] = opts.format
    metadata = MatplotlibFigureType.get_savefig_metadata(opts.format)

    # Normalize column names
    for dataset in datasets:
        if 'p_value' in dataset.colnames:
            dataset.rename_column('p_value', 'searched_prob')

    if opts.group_by == 'far':

        def key_func(table):
            return -np.log10(table['far'])

        def key_to_dir(key):
            return 'far_1e{}'.format(-key)

        def key_to_title(key):
            return r'$\mathrm{{FAR}} \leq 10^{{{}}}$ Hz'.format(-key)

    elif opts.group_by == 'snr':

        def key_func(table):
            return table['snr']

        def key_to_dir(key):
            return 'snr_{}'.format(key)

        def key_to_title(key):
            return r'$\mathrm{{SNR}} \geq {}$'.format(key)

    else:

        def key_func(table):
            return np.zeros(len(table))

        def key_to_dir(key):
            return '.'

        def key_to_title(key):
            return 'All events'

    if opts.group_by is not None:
        missing = [filename for filename, dataset in zip(filenames, datasets)
                   if opts.group_by not in dataset.colnames]
        if missing:
            raise RuntimeError(
                'The following files had no "'
                + opts.group_by + '" column: ' + ' '.join(missing))

    for dataset in datasets:
        dataset['key'] = key_func(dataset)

    if opts.group_by is not None:
        invalid = [filename for filename, dataset in zip(filenames, datasets)
                   if not np.all(np.isfinite(dataset['key']))]
        if invalid:
            raise RuntimeError(
                'The following files had invalid values in the "'
                + opts.group_by + '" column: ' + ' '.join(invalid))

    keys = np.concatenate([dataset['key'] for dataset in datasets])

    histlabel = []
    if opts.cumulative:
        histlabel.append('cumulative')
    if opts.normed:
        histlabel.append('fraction')
    else:
        histlabel.append('number')
    histlabel.append('of injections')
    histlabel = ' '.join(histlabel)

    pp_plot_settings = [
        ['', 'searched posterior mass'],
        ['_dist', 'distance CDF at true distance'],
        ['_vol', 'searched volumetric probability']]
    hist_settings = [
        ['searched_area', 'searched_area (deg$^2$)'],
        ['searched_vol', 'searched volume (Mpc$^3$)'],
        ['offset', 'angle from true location and mode of posterior (deg)'],
        ['runtime', 'run time (s)']]

    keys = range(*np.floor([keys.min(), keys.max()+1]).astype(int))
    total = len(keys) * (len(pp_plot_settings) + len(hist_settings))
    with tqdm(total=total) as progress:
        for key in keys:
            filtered = [d[d['key'] >= key] for d in datasets]
            title = key_to_title(key)
            nsamples = {len(d) for d in filtered}
            if len(nsamples) == 1:
                nsamples, = nsamples
                title += ' ({} events)'.format(nsamples)
            else:
                nsamples = None

            subdir = os.path.join(opts.output, key_to_dir(key))
            mkpath(subdir)

            # Make several different kinds of P-P plots
            for suffix, xlabel in pp_plot_settings:
                colname = 'searched_prob' + suffix
                fig = plt.figure(figsize=(6, 6))
                ax = fig.add_subplot(111, projection='pp_plot')
                fig.subplots_adjust(bottom=0.15)
                ax.set_xlabel(xlabel)
                ax.set_ylabel('cumulative fraction of injections')
                ax.set_title(title)
                for d, label in zip(filtered, labels):
                    ax.add_series(d.columns.get(colname, []), label=label)
                ax.add_diagonal()
                if nsamples:
                    ax.add_confidence_band(
                        nsamples, 0.01 * opts.pp_confidence_interval)
                ax.grid()
                if len(filtered) > 1:
                    ax.legend(loc='lower right')
                fig.savefig(os.path.join(subdir, colname),
                            metadata=metadata)
                plt.close()
                progress.update()

            # Make several different kinds of histograms
            for colname, xlabel in hist_settings:
                fig = plt.figure(figsize=(6, 4.5))
                ax = fig.add_subplot(111)
                fig.subplots_adjust(bottom=0.15)
                ax.set_xscale('log')
                ax.set_xlabel(xlabel)
                ax.set_ylabel(histlabel)
                ax.set_title(title)
                values = np.concatenate(
                    [d.columns.get(colname, []) for d in filtered])
                if len(values) > 0:
                    bins = np.geomspace(np.min(values), np.max(values),
                                        1000 if opts.cumulative else 20)
                    for d, label in zip(filtered, labels):
                        ax.hist(d.columns.get(colname, []),
                                cumulative=opts.cumulative,
                                density=opts.normed, histtype='step',
                                bins=bins, label=label)
                ax.grid()
                ax.legend(loc='upper left')
                fig.savefig(os.path.join(subdir, colname + '_hist'),
                            metadata=metadata)
                plt.close()
                progress.update()

ligo/skymap/tool/ligo_skymap_plot_volume.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
249  
250  
251  
#
# Copyright (C) 2015-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Plot a volumetric posterior in three-projection view."""


from . import ArgumentParser, FileType
from .matplotlib import figure_parser


def parser():
    parser = ArgumentParser(parents=[figure_parser])
    parser.add_argument(
        '--annotate', default=False, action='store_true',
        help='annotate plot with information about the event')
    parser.add_argument(
        '--max-distance', metavar='Mpc', type=float,
        help='maximum distance of plot in Mpc')
    parser.add_argument(
        '--contour', metavar='PERCENT', type=float, nargs='+',
        help='plot contour enclosing this percentage of'
        ' probability mass')
    parser.add_argument(
        '--radecdist', nargs=3, type=float, action='append', default=[],
        help='right ascension (deg), declination (deg), and distance to mark')
    parser.add_argument(
        '--chain', metavar='CHAIN.hdf5', type=FileType('rb'),
        help='optionally plot a posterior sample chain')
    parser.add_argument(
        '--projection', type=int, choices=list(range(4)), default=0,
        help='Plot one specific projection, or 0 for all projections')
    parser.add_argument(
        'input', metavar='INPUT.fits[.gz]', type=FileType('rb'),
        default='-', nargs='?', help='Input FITS file')
    parser.add_argument(
        '--align-to', metavar='SKYMAP.fits[.gz]', type=FileType('rb'),
        help='Align to the principal axes of this sky map')
    parser.set_defaults(figure_width='3.5', figure_height='3.5')
    return parser


def main(args=None):
    opts = parser().parse_args(args)

    # Create progress bar.
    from tqdm import tqdm
    progress = tqdm()
    progress.set_description('Starting up')

    # Late imports
    import astropy_healpix as ah
    from matplotlib import pyplot as plt
    from matplotlib import gridspec
    from matplotlib import transforms
    from .. import io
    from ..plot import marker
    from ..distance import (parameters_to_marginal_moments, principal_axes,
                            volume_render, marginal_pdf)
    import healpy as hp
    import numpy as np
    import scipy.stats

    # Read input, determine input resolution.
    progress.set_description('Loading FITS file')
    (prob, mu, sigma, norm), metadata = io.read_sky_map(
        opts.input.name, distances=True)
    npix = len(prob)
    nside = ah.npix_to_nside(npix)

    progress.set_description('Preparing projection')

    if opts.align_to is None or opts.input.name == opts.align_to.name:
        prob2, mu2, sigma2, norm2 = prob, mu, sigma, norm
    else:
        (prob2, mu2, sigma2, norm2), _ = io.read_sky_map(
            opts.align_to.name, distances=True)
    if opts.max_distance is None:
        mean, std = parameters_to_marginal_moments(prob2, mu2, sigma2)
        max_distance = mean + 2.5 * std
    else:
        max_distance = opts.max_distance
    rot = np.ascontiguousarray(principal_axes(prob2, mu2, sigma2))

    if opts.chain:
        chain = io.read_samples(opts.chain.name)
        chain = np.dot(rot.T, (hp.ang2vec(
            0.5 * np.pi - chain['dec'], chain['ra']) *
            np.atleast_2d(chain['dist']).T).T)

    fig = plt.figure(frameon=False)
    n = 1 if opts.projection else 2
    gs = gridspec.GridSpec(
        n, n, left=0.01, right=0.99, bottom=0.01, top=0.99,
        wspace=0.05, hspace=0.05)

    imgwidth = int(opts.dpi * opts.figure_width / n)
    s = np.linspace(-max_distance, max_distance, imgwidth)
    xx, yy = np.meshgrid(s, s)

    truth_marker = marker.reticle(
        inner=0.5 * np.sqrt(2), outer=1.5 * np.sqrt(2), angle=45)

    for iface, (axis0, axis1, (sp0, sp1)) in enumerate((
            (1, 0, [0, 0]),
            (0, 2, [1, 1]),
            (1, 2, [1, 0]),)):

        if opts.projection and opts.projection != iface + 1:
            continue

        progress.set_description('Plotting projection {0}'.format(iface + 1))

        # Marginalize onto the given face
        density = volume_render(
            xx.ravel(), yy.ravel(), max_distance, axis0, axis1, rot, False,
            prob, mu, sigma, norm).reshape(xx.shape)

        # Plot heat map
        ax = fig.add_subplot(
            gs[0, 0] if opts.projection else gs[sp0, sp1], aspect=1)
        ax.imshow(
            density, origin='lower',
            extent=[-max_distance, max_distance, -max_distance, max_distance],
            cmap=opts.colormap)

        # Add contours if requested
        if opts.contour:
            flattened_density = density.ravel()
            indices = np.argsort(flattened_density)[::-1]
            cumsum = np.empty_like(flattened_density)
            cs = np.cumsum(flattened_density[indices])
            cumsum[indices] = cs / cs[-1] * 100
            cumsum = np.reshape(cumsum, density.shape)
            u, v = np.meshgrid(s, s)
            contourset = ax.contour(
                u, v, cumsum, levels=opts.contour, linewidths=0.5)

        # Mark locations
        ax._get_lines.get_next_color()  # skip default color
        for ra, dec, dist in opts.radecdist:
            theta = 0.5 * np.pi - np.deg2rad(dec)
            phi = np.deg2rad(ra)
            xyz = np.dot(rot.T, hp.ang2vec(theta, phi) * dist)
            ax.plot(
                xyz[axis0], xyz[axis1], marker=truth_marker,
                markerfacecolor='none', markeredgewidth=1)

        # Plot chain
        if opts.chain:
            ax.plot(chain[axis0], chain[axis1], '.k', markersize=0.5)

        # Hide axes ticks
        ax.set_xticks([])
        ax.set_yticks([])

        # Set axis limits
        ax.set_xlim([-max_distance, max_distance])
        ax.set_ylim([-max_distance, max_distance])

        # Mark origin (Earth)
        ax.plot(
            [0], [0], marker=marker.earth, markersize=5,
            markerfacecolor='none', markeredgecolor='black',
            markeredgewidth=0.75)

        if iface == 2:
            ax.invert_xaxis()

    # Add contour labels if contours requested
    if opts.contour:
        ax.clabel(contourset, fmt='%d%%', fontsize=7)

    if not opts.projection:
        # Add scale bar, 1/4 width of the plot
        ax.plot(
            [0.0625, 0.3125], [0.0625, 0.0625],
            color='black', linewidth=1, transform=ax.transAxes)
        ax.text(
            0.0625, 0.0625,
            '{0:d} Mpc'.format(int(np.round(0.5 * max_distance))),
            fontsize=8, transform=ax.transAxes, verticalalignment='bottom')

        # Create marginal distance plot.
        progress.set_description('Plotting distance')
        gs1 = gridspec.GridSpecFromSubplotSpec(5, 5, gs[0, 1])
        ax = fig.add_subplot(gs1[1:-1, 1:-1])

        # Plot marginal distance distribution, integrated over the whole sky.
        d = np.linspace(0, max_distance)
        ax.fill_between(d, marginal_pdf(d, prob, mu, sigma, norm),
                        alpha=0.5, color=ax._get_lines.get_next_color())

        # Plot conditional distance distribution at true position
        # and mark true distance.
        for ra, dec, dist in opts.radecdist:
            theta = 0.5 * np.pi - np.deg2rad(dec)
            phi = np.deg2rad(ra)
            ipix = hp.ang2pix(nside, theta, phi)
            lines, = ax.plot(
                [dist], [-0.15], marker=truth_marker,
                markerfacecolor='none', markeredgewidth=1, clip_on=False,
                transform=transforms.blended_transform_factory(
                    ax.transData, ax.transAxes))
            ax.fill_between(d, scipy.stats.norm(
                mu[ipix], sigma[ipix]).pdf(d) * norm[ipix] * np.square(d),
                alpha=0.5, color=lines.get_color())
            ax.axvline(dist, color='black', linewidth=0.5)

        # Scale axes
        ax.set_xticks([0, max_distance])
        ax.set_xticklabels(
            ['0', "{0:d}\nMpc".format(int(np.round(max_distance)))],
            fontsize=9)
        ax.set_yticks([])
        ax.set_xlim(0, max_distance)
        ax.set_ylim(0, ax.get_ylim()[1])

        if opts.annotate:
            text = []
            try:
                objid = metadata['objid']
            except KeyError:
                pass
            else:
                text.append('event ID: {}'.format(objid))
            try:
                distmean = metadata['distmean']
                diststd = metadata['diststd']
            except KeyError:
                pass
            else:
                text.append('distance: {}±{} Mpc'.format(
                            int(np.round(distmean)), int(np.round(diststd))))
            ax.text(0, 1, '\n'.join(text), transform=ax.transAxes, fontsize=7,
                    ha='left', va='bottom', clip_on=False)

    progress.set_description('Saving')
    opts.output()

ligo/skymap/tool/ligo_skymap_stats.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
156  
157  
158  
159  
160  
161  
162  
163  
164  
165  
166  
167  
168  
169  
170  
171  
172  
173  
174  
175  
176  
177  
178  
179  
180  
181  
182  
183  
184  
185  
186  
187  
188  
189  
190  
191  
192  
193  
194  
195  
196  
197  
198  
199  
200  
201  
202  
203  
204  
205  
206  
207  
208  
209  
210  
211  
212  
213  
214  
215  
216  
217  
218  
219  
220  
221  
222  
223  
224  
225  
226  
227  
228  
229  
230  
231  
232  
233  
234  
235  
236  
237  
238  
239  
240  
241  
242  
243  
244  
245  
246  
247  
248  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""
Calculate summary statistics for a batch of sky maps.

Under the hood, this script is little more than a command-line interface for
the :mod:`ligo.skymap.postprocess.crossmatch` module.

The filenames of the sky maps may be provided as positional command line
arguments, and may also be provided as globs (such as ``*.fits.gz``). If
supplied with the optional ``--database`` argument, then also match sky maps
with injections from an inspinjfind-style sqlite database.

All angular separations are in degrees, all areas are in square degrees, and
all volumes are in cubic megaparsecs. The output is written as tab-separated
values with the following columns:

+-----------------------------------------------------------------------------+
| **From the search pipeline database**                                       |
+------------------------+----------------------------------------------------+
| ``coinc_event_id``     | event identifier                                   |
+------------------------+----------------------------------------------------+
| ``simulation_id``      | injection identifier                               |
+------------------------+----------------------------------------------------+
| ``far``                | false alarm rate                                   |
+------------------------+----------------------------------------------------+
| ``snr``                | signal to noise ratio                              |
+------------------------+----------------------------------------------------+
| **Injection finding**                                                       |
+------------------------+----------------------------------------------------+
| ``searched_area``      | area of 2D credible region containing the true sky |
|                        | location                                           |
+------------------------+----------------------------------------------------+
| ``searched_prob``      | probability in that 2D credible region             |
+------------------------+----------------------------------------------------+
| ``searched_prob_dist`` | marginal distance CDF at the true distance         |
+------------------------+----------------------------------------------------+
| ``searched_vol``       | volume of 3D credible region containing the true   |
|                        | position                                           |
+------------------------+----------------------------------------------------+
| ``searched_prob_vol``  | probability contained in that volume               |
+------------------------+----------------------------------------------------+
| ``offset``             | angular separation between the maximum             |
|                        | *a posteriori* position and the true sky position  |
+------------------------+----------------------------------------------------+
| **Additional metadata from the sky maps**                                   |
+------------------------+----------------------------------------------------+
| ``runtime``            | wall clock run time to generate sky map            |
+------------------------+----------------------------------------------------+
| ``distmean``           | mean *a posteriori* distance                       |
+------------------------+----------------------------------------------------+
| ``diststd``            | *a posteriori* standard deviation of distance      |
+------------------------+----------------------------------------------------+
| ``log_bci``            | natural log Bayes factor, coherent vs. incoherent  |
+------------------------+----------------------------------------------------+
| ``log_bsn``            | natural log Bayes factor, signal vs. noise         |
+------------------------+----------------------------------------------------+
| **Credible levels** (if ``--area`` or ``--contour`` options present)        |
+------------------------+----------------------------------------------------+
| ``area(P)``            | area of the *P* percent 2D credible region         |
+------------------------+----------------------------------------------------+
| ``prob(A)``            | probability contained within the 2D credible level |
|                        | of area *A*                                        |
+------------------------+----------------------------------------------------+
| ``dist(P)``            | distance for a cumulative marginal probability of  |
|                        | *P* percent                                        |
+------------------------+----------------------------------------------------+
| ``vol(P)``             | volume of the *P* percent 3D credible region       |
+------------------------+----------------------------------------------------+
| **Modes** (if ``--modes`` option is present)                                |
+------------------------+----------------------------------------------------+
| ``searched_modes``     | number of simply connected figures in the 2D       |
|                        | credible region containing the true sky location   |
+------------------------+----------------------------------------------------+
| ``modes(P)``           | number of simply connected figures in the *P*      |
|                        | percent 2D credible region                         |
+------------------------+----------------------------------------------------+

"""

from functools import partial
import sys

from astropy.coordinates import SkyCoord
from astropy import units as u
import numpy as np

from . import ArgumentParser, FileType, SQLiteType
from ..io import fits
from ..postprocess import crossmatch


def parser():
    parser = ArgumentParser()
    parser.add_argument(
        '-o', '--output', metavar='OUT.dat', type=FileType('w'), default='-',
        help='Name of output file')
    parser.add_argument(
        '-j', '--jobs', type=int, default=1, const=None, nargs='?',
        help='Number of threads')
    parser.add_argument(
        '-p', '--contour', default=[], nargs='+', type=float,
        metavar='PERCENT',
        help='Report the area of the smallest contour and the number of modes '
        'containing this much probability.')
    parser.add_argument(
        '-a', '--area', default=[], nargs='+', type=float, metavar='DEG2',
        help='Report the largest probability contained within any region '
        'of this area in square degrees. Can be repeated multiple times.')
    parser.add_argument(
        '--modes', action='store_true',
        help='Compute number of disjoint modes')
    parser.add_argument(
        '-d', '--database', type=SQLiteType('r'), metavar='DB.sqlite',
        help='Input SQLite database from search pipeline')
    parser.add_argument(
        'fitsfilenames', metavar='GLOB.fits[.gz]', nargs='+', action='glob',
        help='Input FITS filenames and/or globs')
    parser.add_argument(
        '--cosmology', action='store_true',
        help='Report volume localizations as comoving volumes.')
    return parser


def process(fitsfilename, db, contours, modes, areas, cosmology):
    sky_map = fits.read_sky_map(fitsfilename, moc=True)

    coinc_event_id = sky_map.meta.get('objid')
    try:
        runtime = sky_map.meta['runtime']
    except KeyError:
        runtime = float('nan')

    contour_pvalues = 0.01 * np.asarray(contours)

    if db is None:
        simulation_id = true_ra = true_dec = true_dist = far = snr = None
    else:
        row = db.execute(
            """
            SELECT DISTINCT sim.simulation_id AS simulation_id,
            sim.longitude AS ra, sim.latitude AS dec, sim.distance AS distance,
            ci.combined_far AS far, ci.snr AS snr
            FROM coinc_event_map AS cem1 INNER JOIN coinc_event_map AS cem2
            ON (cem1.coinc_event_id = cem2.coinc_event_id)
            INNER JOIN sim_inspiral AS sim
            ON (cem1.event_id = sim.simulation_id)
            INNER JOIN coinc_inspiral AS ci
            ON (cem2.event_id = ci.coinc_event_id)
            WHERE cem1.table_name = 'sim_inspiral'
            AND cem2.table_name = 'coinc_event' AND cem2.event_id = ?
            """, (coinc_event_id,)).fetchone()
        if row is None:
            return None
        simulation_id, true_ra, true_dec, true_dist, far, snr = row

    if true_ra is None or true_dec is None:
        true_coord = None
    elif true_dist is None:
        true_coord = SkyCoord(true_ra * u.rad, true_dec * u.rad)
    else:
        true_coord = SkyCoord(true_ra * u.rad, true_dec * u.rad,
                              true_dist * u.Mpc)

    (
        searched_area, searched_prob, offset, searched_modes, contour_areas,
        area_probs, contour_modes, searched_prob_dist, contour_dists,
        searched_vol, searched_prob_vol, contour_vols, probdensity,
        probdensity_vol
    ) = crossmatch(
        sky_map, true_coord,
        contours=contour_pvalues, areas=areas, modes=modes, cosmology=cosmology
    )

    if snr is None:
        snr = np.nan
    if far is None:
        far = np.nan
    distmean = sky_map.meta.get('distmean', np.nan)
    diststd = sky_map.meta.get('diststd', np.nan)
    log_bci = sky_map.meta.get('log_bci', np.nan)
    log_bsn = sky_map.meta.get('log_bsn', np.nan)

    ret = [coinc_event_id]
    if db is not None:
        ret += [
            simulation_id, far, snr, searched_area,
            searched_prob, searched_prob_dist, searched_vol, searched_prob_vol,
            offset]
    ret += [runtime, distmean, diststd, log_bci, log_bsn]
    ret += contour_areas + area_probs + contour_dists + contour_vols
    if modes:
        if db is not None:
            ret += [searched_modes]
        ret += contour_modes
    return ret


def main(args=None):
    p = parser()
    opts = p.parse_args(args)

    from ..util.progress import progress_map
    from .. import omp

    if opts.jobs != 1:
        omp.num_threads = 1  # disable OpenMP parallelism

    if args is None:
        print('#', *sys.argv, file=opts.output)
    else:
        print('#', p.prog, *args, file=opts.output)

    colnames = ['coinc_event_id']
    if opts.database is not None:
        colnames += ['simulation_id', 'far', 'snr', 'searched_area',
                     'searched_prob', 'searched_prob_dist', 'searched_vol',
                     'searched_prob_vol', 'offset']
    colnames += ['runtime', 'distmean', 'diststd', 'log_bci', 'log_bsn']
    colnames += ['area({0:g})'.format(_) for _ in opts.contour]
    colnames += ['prob({0:g})'.format(_) for _ in opts.area]
    colnames += ['dist({0:g})'.format(_) for _ in opts.contour]
    colnames += ['vol({0:g})'.format(_) for _ in opts.contour]
    if opts.modes:
        if opts.database is not None:
            colnames += ['searched_modes']
        colnames += ["modes({0:g})".format(p) for p in opts.contour]
    print(*colnames, sep="\t", file=opts.output)

    func = partial(process, db=opts.database, contours=opts.contour,
                   modes=opts.modes, areas=opts.area, cosmology=opts.cosmology)
    for record in progress_map(func, opts.fitsfilenames, jobs=opts.jobs):
        if record is not None:
            print(*record, sep="\t", file=opts.output)

ligo/skymap/tool/ligo_skymap_unflatten.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
"""Convert a HEALPix FITS file to multi-resolution UNIQ indexing from the more
common IMPLICIT indexing.
"""

from . import ArgumentParser, FileType


def parser():
    parser = ArgumentParser()
    parser.add_argument('input', metavar='INPUT.fits',
                        type=FileType('rb'), help='Input FITS file')
    parser.add_argument('output', metavar='OUTPUT.fits[.gz]',
                        type=FileType('wb'), help='Output FITS file')
    return parser


def main(args=None):
    """Convert a HEALPix FITS sky map to multi-resolution (UNIQ) indexing.

    A warning is issued if the input's ``ORDERING`` header keyword is neither
    ``NESTED`` nor ``RING``; the conversion is attempted regardless.

    Parameters
    ----------
    args : list of str, optional
        Command-line arguments; defaults to ``sys.argv[1:]`` via argparse.
    """
    args = parser().parse_args(args)

    import warnings
    from astropy.io import fits
    from ..io import read_sky_map, write_sky_map

    # Close the HDU list (and the input file) when done, even on error. The
    # output is written while it is still open in case the table's columns
    # are backed by the input file's data.
    with fits.open(args.input) as hdus:
        ordering = hdus[1].header['ORDERING']
        # A tuple (not a set) so that the warning text is deterministic.
        expected_orderings = ('NESTED', 'RING')
        if ordering not in expected_orderings:
            msg = 'Expected the FITS file {} to have ordering {}, but it is {}'
            warnings.warn(msg.format(
                args.input.name, ' or '.join(expected_orderings), ordering))
        table = read_sky_map(hdus, moc=True)
        write_sky_map(args.output.name, table)

ligo/skymap/tool/matplotlib.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
#
# Copyright (C) 2013-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Functions that support the command line interface for plotting tools."""

import argparse
import os
import sys

import matplotlib

from ..plot import cmap  # noqa
from . import FileType, HelpChoicesAction, type_with_sideeffect, version_string

# Set no-op Matplotlib backend to defer importing anything that requires a GUI
# until we have determined that it is necessary based on the command line
# arguments.
if 'matplotlib.pyplot' in sys.modules:
    # pyplot is already loaded, so the backend has to be switched through it.
    from matplotlib import pyplot as plt
    plt.switch_backend('Template')
else:
    # The `warn` keyword of `matplotlib.use` was deprecated in Matplotlib 3.1
    # and removed in 3.3, where passing it raises TypeError. `force=True`
    # alone selects the backend and is accepted by old and new releases.
    matplotlib.use('Template', force=True)
    from matplotlib import pyplot as plt

__all__ = ('figure_parser',)


class MatplotlibFigureType(FileType):
    """Argument type that maps an output argument to a finishing action.

    The value ``'-'`` restores the backend Matplotlib was originally
    configured with and yields a callable that shows the figure on screen.
    Any other value is treated as a file name. It is opened once (mode
    ``'wb'``) through the `FileType` base class and closed again, the ``agg``
    backend is selected, and the result is a callable that saves the current
    figure to that file.
    """

    def __init__(self):
        super().__init__('wb')

    @staticmethod
    def __show():
        """Display the current figure(s) on screen."""
        from matplotlib import pyplot
        return pyplot.show()

    @staticmethod
    def get_savefig_metadata(format):
        """Build the ``metadata`` argument for `matplotlib.pyplot.savefig`.

        ``Title`` records the command line that produced the figure. The tool
        version is also recorded under ``Software`` for PNG output, or under
        ``Creator`` for PDF, PS, and EPS output.
        """
        script = os.path.basename(sys.argv[0])
        program = os.path.splitext(script)[0]
        metadata = {'Title': ' '.join([program, *sys.argv[1:]])}
        version_key = {
            'png': 'Software',
            'pdf': 'Creator',
            'ps': 'Creator',
            'eps': 'Creator',
        }.get(format)
        if version_key is not None:
            metadata[version_key] = version_string
        return metadata

    def __save(self):
        """Save the current figure to the file named by ``self.string``."""
        from matplotlib import pyplot
        extension = os.path.splitext(self.string)[1]
        metadata = self.get_savefig_metadata(extension.lower().lstrip('.'))
        return pyplot.savefig(self.string, metadata=metadata)

    def __call__(self, string):
        from matplotlib import pyplot
        if string == '-':
            pyplot.switch_backend(matplotlib.rcParamsOrig['backend'])
            return self.__show
        # Open the output once (and close it immediately) to validate it.
        with super().__call__(string):
            pass
        pyplot.switch_backend('agg')
        self.string = string
        return self.__save


@type_with_sideeffect(str)
def colormap(value):
    """Make *value* the default colormap (``image.cmap``)."""
    matplotlib.rcParams['image.cmap'] = value


@type_with_sideeffect(float)
def figwidth(value):
    """Set the default figure width, in inches, to *value*."""
    figsize = matplotlib.rcParams['figure.figsize']
    figsize[0] = float(value)


@type_with_sideeffect(float)
def figheight(value):
    """Set the default figure height, in inches, to *value*."""
    figsize = matplotlib.rcParams['figure.figsize']
    figsize[1] = float(value)


@type_with_sideeffect(int)
def dpi(value):
    """Use *value* dots per inch for both ``figure.dpi`` and ``savefig.dpi``."""
    resolution = float(value)
    matplotlib.rcParams['figure.dpi'] = resolution
    matplotlib.rcParams['savefig.dpi'] = resolution


@type_with_sideeffect(int)
def transparent(value):
    """Turn transparent backgrounds for saved figures on (nonzero) or off."""
    matplotlib.rcParams['savefig.transparent'] = True if value else False


figure_parser = argparse.ArgumentParser(add_help=False)
group = figure_parser.add_argument_group(
    'figure options', 'Options that affect figure output format')
# '-' (the default) shows the figure on screen; anything else is a file name.
group.add_argument(
    '-o', '--output', default='-', type=MatplotlibFigureType(),
    metavar='FILE.{pdf,png}', help='output file, or - to plot to screen')
group.add_argument(
    '--colormap', type=colormap, default='cylon', metavar='CMAP',
    choices=plt.colormaps(), help='matplotlib colormap')
group.add_argument(
    '--help-colormap', choices=plt.colormaps(), action=HelpChoicesAction)
group.add_argument(
    '--figure-width', type=figwidth, default='8', metavar='INCHES',
    help='width of figure in inches')
group.add_argument(
    '--figure-height', type=figheight, default='6', metavar='INCHES',
    help='height of figure in inches')
# NOTE(review): argparse runs `type` on a default only when the default is a
# string. Unlike the string defaults above, the integer default 300 is stored
# as-is, so the `dpi` side effect (updating rcParams) only happens when --dpi
# is given explicitly. Confirm whether that is intended.
group.add_argument(
    '--dpi', type=dpi, default=300, metavar='PIXELS',
    help='resolution of figure in dots per inch')
group.add_argument(
    '--transparent', type=transparent, nargs='?', const='1', default='0',
    help='Save image with transparent background')
del group

ligo/skymap/util/__init__.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
import os
import pkgutil

__all__ = ()


def _import_all_submodules():
    """Re-export the public names of every submodule of this package.

    For each submodule except ``tests``, this does the equivalent of
    ``from . import <name>`` followed by ``from .<name> import *``. It also
    appends the submodule's ``__all__`` (if it has one) to this package's
    ``__all__``. Unlike building and ``exec``-ing source code, this accepts
    an ``__all__`` that is a list as well as a tuple, and leaves no loop
    variables behind in the package namespace.
    """
    global __all__
    import importlib

    namespace = globals()
    for module_info in pkgutil.iter_modules([os.path.dirname(__file__)]):
        name = module_info[1]
        if name == 'tests':
            continue
        # Importing a submodule also binds it as an attribute of this package
        # (the equivalent of ``from . import <name>``).
        submodule = importlib.import_module('.' + name, __name__)
        public = getattr(submodule, '__all__', None)
        if public is None:
            # Star-import rule for modules without __all__: every name that
            # does not start with an underscore.
            public = [key for key in vars(submodule)
                      if not key.startswith('_')]
        else:
            __all__ += tuple(public)
        for key in public:
            namespace[key] = getattr(submodule, key)


_import_all_submodules()

# Clean up
del _import_all_submodules, os, pkgutil

ligo/skymap/util/file.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
#
# Copyright (C) 2018-2019  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""File tools."""
import errno
import os
import shutil
import tempfile


def rename(src, dst):
    """Like `os.rename`, but works across different devices because it
    catches and handles ``EXDEV`` (``Invalid cross-device link``) errors.

    When *src* and *dst* are on different devices, *src* is copied (with
    metadata, via `shutil.copy2`) to a temporary file in the directory of
    *dst*. That file is renamed onto *dst*, and finally *src* is removed, so
    the end result matches a same-device `os.rename`.

    Parameters
    ----------
    src : str
        Path of the file to move.
    dst : str
        Destination path.

    Raises
    ------
    OSError
        Any error from `os.rename` other than ``EXDEV``, or any error while
        copying the file or removing *src*.
    """
    try:
        os.rename(src, dst)
        return
    except OSError as e:
        if e.errno != errno.EXDEV:
            raise

    # Cross-device move: stage the copy next to dst, so that the final rename
    # happens within one directory and dst never appears partially written.
    dst_dir, suffix = os.path.split(dst)
    tmpfid, tmpdst = tempfile.mkstemp(dir=dst_dir, suffix=suffix)
    try:
        os.close(tmpfid)
        shutil.copy2(src, tmpdst)
        os.rename(tmpdst, dst)
    except BaseException:
        # Remove the staging file, but never let a failure to do so hide the
        # error that got us here.
        try:
            os.remove(tmpdst)
        except OSError:
            pass
        raise

    # Complete the move, as os.rename would have.
    os.remove(src)


def rm_f(filename):
    """Remove a file, or be silent if the file does not exist, like ``rm -f``.

    Parameters
    ----------
    filename : str
        Path of the file to remove.

    Raises
    ------
    OSError
        For any failure other than the file not existing (for example,
        permission denied).

    Examples
    --------
    >>> with tempfile.TemporaryDirectory() as d:
    ...     filename = os.path.join(d, 'test')
    ...     rm_f(filename)
    ...     with open(filename, 'w') as f:
    ...         print('Hello world', file=f)
    ...     rm_f(filename)
    ...     os.path.exists(filename)
    False

    """
    try:
        os.remove(filename)
    except FileNotFoundError:
        pass

ligo/skymap/util/ilwd.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
#
# Copyright (C) 2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Tools for adapting LIGO-LW row ID formats."""
from ligo.lw.ligolw import Param
from ligo.lw.lsctables import TableByName
from ligo.lw.table import Column, TableStream
from ligo.lw.types import FormatFunc, FromPyType, IDTypes, ToPyType

__all__ = ('use_in',)

# Target representation for row IDs: plain Python ints. The matching LIGO-LW
# type name and text formatter are looked up from ligo.lw.types (presumably
# an integer type such as int_8s -- confirm against ligo.lw.types.FromPyType).
ROWID_PYTYPE = int
ROWID_TYPE = FromPyType[ROWID_PYTYPE]
ROWID_FORMATFUNC = FormatFunc[ROWID_TYPE]


def use_in(ContentHandler):
    """Convert from old-style to new-style row IDs on the fly.

    This is loosely adapted from :func:`ligo.lw.utils.ilwd.strip_ilwdchar`.

    The ``endElementNS``, ``startColumn``, and ``startStream`` methods of
    `ContentHandler` are replaced. The class is modified in place and also
    returned. The effect is that ilwdchar-typed ``<Param>`` values and table
    column values are converted to plain Python ints while parsing.

    Notes
    -----
    When building a ContentHandler, this must be the _outermost_ decorator,
    outside of :func:`ligo.lw.lsctables.use_in`, :func:`ligo.lw.param.use_in`,
    or :func:`ligo.lw.table.use_in`.

    Examples
    --------
    >>> from pkg_resources import resource_filename
    >>> from ligo.lw import array, ligolw, lsctables, param, table, utils
    >>> from ligo.skymap.util import ilwd
    >>> @ilwd.use_in
    ... @lsctables.use_in
    ... @param.use_in
    ... @table.use_in
    ... class ContentHandler(ligolw.LIGOLWContentHandler):
    ...     pass
    >>> filename = resource_filename(
    ...     'ligo.skymap.io.tests', 'data/G197392_coinc.xml.gz')
    >>> xmldoc = utils.load_filename(filename, contenthandler=ContentHandler)
    >>> table = lsctables.SnglInspiralTable.get_table(xmldoc)
    >>> table[0].process_id
    0

    """

    def endElementNS(self, uri_localname, qname,
                     __orig_endElementNS=ContentHandler.endElementNS):
        """Convert values of <Param> elements from ilwdchar to int."""
        if isinstance(self.current, Param) and self.current.Type in IDTypes:
            old_type = ToPyType[self.current.Type]
            new_value = ROWID_PYTYPE(old_type(self.current.pcdata))
            self.current.Type = ROWID_TYPE
            self.current.pcdata = ROWID_FORMATFUNC(new_value)
        __orig_endElementNS(self, uri_localname, qname)

    # Per-column value converters, keyed by (id(table element), column name).
    # Filled in by startColumn and consumed by startStream. The dict is shared
    # by every instance of the decorated class.
    remapped = {}

    def startColumn(self, parent, attrs,
                    __orig_startColumn=ContentHandler.startColumn):
        """Convert types in <Column> elements from ilwdchar to int.

        Notes
        -----
        This method is adapted from
        :func:`ligo.lw.utils.ilwd.strip_ilwdchar`.

        """
        result = __orig_startColumn(self, parent, attrs)

        # If this is an ilwdchar column, then create a function to convert its
        # rows' values for use in the startStream method below.
        if result.Type in IDTypes:
            old_type = ToPyType[result.Type]

            def converter(old_value):
                return ROWID_PYTYPE(old_type(old_value))

            remapped[(id(parent), result.Name)] = converter
            result.Type = ROWID_TYPE

        # If the column name is not one of the table's valid column names, try
        # to match it against the valid names after normalizing them with
        # Column.ColumnName (presumably this strips old-style table-name
        # prefixes -- confirm against ligo.lw.table.Column.ColumnName).
        if parent.Name in TableByName:
            validcolumns = TableByName[parent.Name].validcolumns
            if result.Name not in validcolumns:
                stripped_column_to_valid_column = {
                    Column.ColumnName(name): name for name in validcolumns}
                if result.Name in stripped_column_to_valid_column:
                    result.setAttribute(
                        'Name', stripped_column_to_valid_column[result.Name])

        return result

    def startStream(self, parent, attrs,
                    __orig_startStream=ContentHandler.startStream):
        """Convert values in table <Stream> elements from ilwdchar to int.

        Notes
        -----
        This method is adapted from
        :meth:`ligo.lw.table.TableStream.config`.

        """
        result = __orig_startStream(self, parent, attrs)
        if isinstance(result, TableStream):
            loadcolumns = set(parent.columnnames)
            if parent.loadcolumns is not None:
                # FIXME:  convert loadcolumns attributes to sets to
                # avoid the conversion.
                loadcolumns &= set(parent.loadcolumns)
            types = []
            for pytype, colname in zip(parent.columnpytypes,
                                       parent.columnnames):
                # Consume this column's converter even when the column is not
                # loaded. Otherwise the entry would linger in `remapped`, and
                # a later table that reuses the same id() could pick up a
                # stale converter.
                converter = remapped.pop((id(parent), colname), pytype)
                types.append(converter if colname in loadcolumns else None)
            result._tokenizer.set_types(types)
        return result

    ContentHandler.endElementNS = endElementNS
    ContentHandler.startColumn = startColumn
    ContentHandler.startStream = startStream

    return ContentHandler

ligo/skymap/util/numpy.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
#
# Copyright (C) 2018-2019  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import functools
import numpy as np

__all__ = ('add_newdoc_ufunc', 'require_contiguous_aligned')


def add_newdoc_ufunc(func, doc):  # pragma: no cover
    """The function `np.lib.add_newdoc_ufunc` can only change a ufunc's
    docstring if it is `NULL`. This workaround avoids an exception when the
    user tries to `reload()` this module.
    """
    try:
        np.lib.add_newdoc_ufunc(func, doc)
    except ValueError as e:
        msg = 'Cannot change docstring of ufunc with non-NULL docstring'
        if e.args[0] == msg:
            pass


def require_contiguous_aligned(func):
    """Decorate a Numpy ufunc so its array inputs are contiguous and aligned.

    Only the first ``func.nin`` positional arguments are affected. Scalars
    among them, any further positional arguments, and all keyword arguments
    are passed through unchanged. Non-scalar inputs go through `np.require`
    with the ``CONTIGUOUS`` and ``ALIGNED`` requirements.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        num_inputs = func.nin
        coerced = list(args)
        for index, value in enumerate(coerced[:num_inputs]):
            if not np.isscalar(value):
                coerced[index] = np.require(
                    value, requirements={'CONTIGUOUS', 'ALIGNED'})
        return func(*coerced, **kwargs)
    return wrapper

ligo/skymap/util/progress.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
#
# Copyright (C) 2019-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Tools for progress bars"""

try:
    from billiard import Pool
except ImportError:
    from multiprocessing import Pool
from heapq import heappop, heappush
from operator import length_hint

from tqdm.auto import tqdm

__all__ = ('progress_map',)


class WrappedFunc:

    def __init__(self, func):
        self.func = func

    def __call__(self, i_args):
        i, args = i_args
        return i, self.func(*args)


def _get_total_estimate(*iterables):
    """Estimate total loop iterations for mapping over multiple iterables."""
    estimates = (length_hint(iterable, -1) for iterable in iterables)
    valid_estimates = (estimate for estimate in estimates if estimate != -1)
    return min(valid_estimates, default=None)


def _results_in_order(completed):
    """Put results back into order and yield them as quickly as they arrive."""
    heap = []
    current = 0
    for i_result in completed:
        i, result = i_result
        if i == current:
            yield result
            current += 1
            while heap and heap[0][0] == current:
                _, result = heappop(heap)
                yield result
                current += 1
        else:
            heappush(heap, i_result)
    assert not heap, 'The heap must be empty'


def progress_map(func, *iterables, jobs=1, **kwargs):
    """Map a function across iterables of arguments.

    This is comparable to :meth:`astropy.utils.console.ProgressBar.map`, except
    that it is implemented using :mod:`tqdm` and so provides more detailed and
    accurate progress information.

    With ``jobs == 1`` the work runs serially in this process. Otherwise it
    is distributed over a pool of ``jobs`` worker processes. Results are
    still yielded in input order. Extra keyword arguments are passed to
    `tqdm`.
    """
    total = _get_total_estimate(*iterables)
    if jobs != 1:
        with Pool(jobs) as pool:
            # Workers finish in arbitrary order; tag each task with its index
            # so that the results can be put back in order afterwards.
            unordered = pool.imap_unordered(
                WrappedFunc(func), enumerate(zip(*iterables)))
            yield from _results_in_order(
                tqdm(unordered, total=total, **kwargs))
        return
    yield from tqdm(map(func, *iterables), total=total, **kwargs)

ligo/skymap/util/sqlite.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
89  
90  
91  
92  
93  
94  
95  
96  
97  
98  
99  
100  
101  
102  
103  
104  
105  
106  
107  
108  
109  
110  
111  
112  
113  
114  
115  
116  
117  
118  
119  
120  
121  
122  
123  
124  
125  
126  
127  
128  
129  
130  
131  
132  
133  
134  
135  
136  
137  
138  
139  
140  
141  
142  
143  
144  
145  
146  
147  
148  
149  
150  
151  
152  
153  
154  
155  
#
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Tools for reading and writing SQLite databases."""
import copyreg
import sqlite3
# Keep a handle on the built-in `open`, because this module defines its own
# `open` function below, which shadows the built-in.
_open = open


def _open_a(string):
    return sqlite3.connect(string, check_same_thread=False)


def _open_r(string):
    return sqlite3.connect('file:{}?mode=ro'.format(string),
                           check_same_thread=False, uri=True)


def _open_w(string):
    """Truncate (or create) the file at *string*, then open it as a database.

    Clobbering the file first means the resulting database starts out empty.
    """
    _open(string, 'wb').close()
    return sqlite3.connect(string, check_same_thread=False)


# Map each `open`-style mode flag to the function that implements it.
_openers = dict(a=_open_a, r=_open_r, w=_open_w)


def open(string, mode):
    """Open an SQLite database with an `open`-style mode flag.

    Parameters
    ----------
    string : str
        Path of the SQLite database file
    mode : {'r', 'w', 'a'}
        Access mode: read only, clobber and overwrite, or modify in place.

    Returns
    -------
    connection : `sqlite3.Connection`

    Raises
    ------
    ValueError
        If the filename is invalid (e.g. ``/dev/stdin``), or if the requested
        mode is invalid
    OSError
        If the database could not be opened in the specified mode. The
        underlying error is available as ``__cause__``.

    Examples
    --------
    >>> import tempfile
    >>> import os
    >>> with tempfile.TemporaryDirectory() as d:
    ...     open(os.path.join(d, 'test.sqlite'), 'w')
    ...
    <sqlite3.Connection object at 0x...>

    >>> with tempfile.TemporaryDirectory() as d:
    ...     open(os.path.join(d, 'test.sqlite'), 'r')
    ...
    Traceback (most recent call last):
      ...
    OSError: Failed to open database ...

    >>> open('/dev/stdin', 'r')
    Traceback (most recent call last):
      ...
    ValueError: Cannot open stdin/stdout as an SQLite database

    >>> open('test.sqlite', 'x')
    Traceback (most recent call last):
      ...
    ValueError: Invalid mode "x". Must be one of "arw".

    """
    if string in {'-', '/dev/stdin', '/dev/stdout'}:
        raise ValueError('Cannot open stdin/stdout as an SQLite database')
    try:
        opener = _openers[mode]
    except KeyError:
        # The KeyError is an implementation detail; don't chain it.
        raise ValueError('Invalid mode "{}". Must be one of "{}".'.format(
            mode, ''.join(sorted(_openers.keys())))) from None
    try:
        return opener(string)
    except (OSError, sqlite3.Error) as e:
        raise OSError('Failed to open database {}: {}'.format(string, e)) \
            from e


def get_filename(connection):
    r"""Get the name of the file associated with an SQLite connection.

    Parameters
    ----------
    connection : `sqlite3.Connection`
        The database connection

    Returns
    -------
    str
        The name of the file that contains the SQLite database (an empty
        string for an in-memory database)

    Raises
    ------
    RuntimeError
        If more than one database is attached to the connection

    Examples
    --------
    >>> import tempfile
    >>> import os
    >>> with tempfile.TemporaryDirectory() as d:
    ...     with sqlite3.connect(os.path.join(d, 'test.sqlite')) as db:
    ...         print(get_filename(db))
    ...
    /.../test.sqlite

    >>> with tempfile.TemporaryDirectory() as d:
    ...     with sqlite3.connect(os.path.join(d, 'test1.sqlite')) as db1, \
    ...          sqlite3.connect(os.path.join(d, 'test2.sqlite')) as db2:
    ...         filename = get_filename(db1)
    ...         db2.execute('ATTACH DATABASE "{}" AS db2'.format(filename))
    ...         print(get_filename(db2))
    ...
    Traceback (most recent call last):
      ...
    RuntimeError: Expected exactly one attached database

    """
    # Each row of `pragma database_list` is (seq, name, file).
    result = connection.execute('pragma database_list').fetchall()
    try:
        (_, _, filename), = result
    except ValueError:
        # The unpacking error itself carries no useful information.
        raise RuntimeError('Expected exactly one attached database') from None
    return filename


def pickle_sqlite3_connection(obj):
    """Reduce an SQLite connection to ``(_open_a, (filename,))`` for pickling.

    On unpickling, the connection is re-created by reopening the same
    database file with `_open_a` (modify-in-place mode), whatever mode the
    original connection had.
    """
    filename = get_filename(obj)
    return _open_a, (filename,)


# Register the reducer so that `pickle` can serialize `sqlite3.Connection`.
copyreg.pickle(sqlite3.Connection, pickle_sqlite3_connection)

ligo/skymap/util/stopwatch.py

1  
2  
3  
4  
5  
6  
7  
8  
9  
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  
39  
40  
41  
42  
43  
44  
45  
46  
47  
48  
49  
50  
51  
52  
53  
54  
55  
56  
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  
75  
76  
77  
78  
79  
80  
81  
82  
83  
84  
85  
86  
87  
88  
#
# Copyright (C) 2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Performance measurement utilities."""
from resource import getrusage, RUSAGE_SELF
from time import perf_counter

import numpy as np


class StopwatchTimes:
    """Wall-clock (``real``), user CPU, and system CPU times, in seconds.

    Supports component-wise ``+``, ``-``, ``+=`` and ``-=``.
    """

    def __init__(self, real=0, user=0, sys=0):
        self.real, self.user, self.sys = real, user, sys

    def __iadd__(self, other):
        self.real += other.real
        self.user += other.user
        self.sys += other.sys
        return self

    def __isub__(self, other):
        self.real -= other.real
        self.user -= other.user
        self.sys -= other.sys
        return self

    def __add__(self, other):
        real_sum = self.real + other.real
        user_sum = self.user + other.user
        sys_sum = self.sys + other.sys
        return StopwatchTimes(real_sum, user_sum, sys_sum)

    def __sub__(self, other):
        real_diff = self.real - other.real
        user_diff = self.user - other.user
        sys_diff = self.sys - other.sys
        return StopwatchTimes(real_diff, user_diff, sys_diff)

    def __repr__(self):
        fields = ', '.join(
            f'{key}={getattr(self, key)!r}' for key in ('real', 'user', 'sys'))
        return f'{type(self).__name__}({fields})'

    def __str__(self):
        def fmt(seconds):
            # Fixed-point with exactly three digits after the decimal point.
            return np.format_float_positional(seconds, 3, unique=False)
        return 'real={}s, user={}s, sys={}s'.format(
            fmt(self.real), fmt(self.user), fmt(self.sys))

    @classmethod
    def now(cls):
        """Sample the current counters for this process."""
        usage = getrusage(RUSAGE_SELF)
        real = perf_counter()
        return cls(real, usage.ru_utime, usage.ru_stime)

    def reset(self):
        self.real = 0
        self.user = 0
        self.sys = 0


class Stopwatch(StopwatchTimes):
    """A code profiling utility that mimics the interface of a stopwatch."""