Loading [MathJax]/extensions/TeX/AMSsymbols.js
LALPulsar 7.1.1.1-3a66518
All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Macros Modules Pages
test_flat_prior.py
Go to the documentation of this file.
"""
A script to run lalpulsar_parameter_estimation_nested with increasing amplitude
prior ranges. Samples supposedly generated from the prior are output and checked
using a Kolmogorov-Smirnov (K-S) test to see if they match the expected flat
prior across the range.

Exit status: 0 if every prior range passes, 1 if an H0 prior range fails the
flatness check, 77 if LALInference is disabled (test skipped).
"""

import os
import subprocess as sp
import sys

import h5py
import numpy as np
import scipy.stats as ss

# Exit status 77 marks the test as skipped when LALInference is disabled. An
# unset LALINFERENCE_ENABLED is treated as "not disabled" instead of aborting
# with a KeyError.
if os.environ.get("LALINFERENCE_ENABLED") == "false":
    print("Skipping test: requires LALInference")
    sys.exit(77)
17
# Overall exit status of the script; set to 1 when a prior range fails the
# flatness check.
exit_code = 0

execu = "./lalpulsar_parameter_estimation_nested"  # executable

# lalpulsar_parameter_estimation_nested runs much slower with memory debugging,
# so strip "memdbg" from LAL_DEBUG_LEVEL. When the variable is unset there is
# nothing to strip, so leave the environment alone rather than raise KeyError.
debug_level = os.environ.get("LAL_DEBUG_LEVEL")
if debug_level is not None:
    os.environ["LAL_DEBUG_LEVEL"] = debug_level.replace("memdbg", "")
    print("Modified LAL_DEBUG_LEVEL='%s'" % os.environ["LAL_DEBUG_LEVEL"])
25
# create files needed to run the code

# par file for the test source: RAJ = DECJ = 00:00:00.0, F0 = 100,
# PEPOCH = 54000
parfile = """\
PSRJ J0000+0000
RAJ 00:00:00.0
DECJ 00:00:00.0
F0 100
PEPOCH 54000"""

parf = "test.par"
with open(parf, "w") as f:
    f.write(parfile)

# data file: 1440 rows at 60 s spacing (one day). Column 0 holds the time
# stamps; the other two columns hold zero-mean Gaussian noise with standard
# deviation 1e-24 (presumably the real and imaginary parts of the input
# data -- confirm against the executable's expected input format).
datafile = "data.txt.gz"
ds = np.zeros((1440, 3))
ds[:, 0] = np.linspace(900000000.0, 900000000.0 + 86400.0 - 60.0, 1440)  # time stamps
ds[:, -2:] = 1.0e-24 * np.random.randn(1440, 2)

# output data file (the ".gz" suffix makes np.savetxt gzip-compress it)
np.savetxt(datafile, ds, fmt="%.12e")

# range of upper limits on h0 in prior file (one run per value)
h0uls = [1e-22, 1e-21, 1e-20, 1e-19, 1e-18, 1e-17]

# some default inputs (kept as strings: they are passed as command-line args)
dets = "H1"
Nlive = "5000"
Nmcmcinitial = "0"
outfile = "test.hdf"
outfile_SNR = "test_SNR"
outfile_Znoise = "test_Znoise"
priorsamples = Nlive  # value passed to --sampleprior
61
# Run the sampler once per H0 upper limit and check that the H0 samples it
# writes are consistent with a flat prior on [0, h0ul]. The loop stops at the
# first range that fails, setting exit_code to 1.
for h, h0ul in enumerate(h0uls):
    print("--- h0=%i/%i ---" % (h + 1, len(h0uls)), flush=True)

    # prior file: flat priors on H0, PHI0, COSIOTA and PSI; only the H0 upper
    # limit changes between iterations
    priorfile = """\
H0 uniform 0 %e
PHI0 uniform 0 %f
COSIOTA uniform -1 1
PSI uniform 0 %f""" % (
        h0ul,
        np.pi,
        np.pi / 2.0,
    )

    priorf = "test.prior"
    with open(priorf, "w") as f:
        f.write(priorfile)

    # run code; arguments are passed as a list so no shell is involved
    commandline = [
        execu,
        "--detectors", dets,
        "--par-file", parf,
        "--input-files", datafile,
        "--outfile", outfile,
        "--prior-file", priorf,
        "--Nlive", Nlive,
        "--Nmcmcinitial", Nmcmcinitial,
        "--sampleprior", priorsamples,
    ]
    sp.check_call(commandline)

    # read in prior samples; the context manager closes the HDF5 file before
    # it is deleted at the end of this iteration
    with h5py.File(outfile, "r") as f:
        nested = f["lalinference"]["lalinference_nest"]["nested_samples"]
        h0samps = nested["H0"][:]  # [:] reads the data into memory

    # normalised histogram of the samples over [0, h0ul] (samples outside that
    # range are ignored by np.histogram) and its running sum, i.e. the
    # empirical CDF at the right edge of each of the 20 bins
    n, nedges = np.histogram(h0samps, bins=20, range=(0.0, h0ul), density=True)
    nc = np.cumsum(n) * (nedges[1] - nedges[0])

    # NOTE(review): this K-S test treats the 20 cumulative values in `nc` as
    # the sample and compares them with U(0, 1); it does not test the H0
    # samples themselves, so it is likely a fairly insensitive flatness check.
    # A direct test would be ss.kstest(h0samps, "uniform", args=(0.0, h0ul)),
    # but it is much stricter and assumes independent samples -- confirm the
    # sampler output satisfies that before switching.
    stat, p = ss.kstest(nc, "uniform")

    print("K-S test p-value for upper range of %e = %f" % (h0ul, p))

    if p < 0.005:
        print("There might be a problem for this prior distribution")
        exit_code = 1
        try:
            import matplotlib as mpl

            mpl.use("Agg")  # non-interactive backend: render to file only
            import matplotlib.pyplot as pl
        except ModuleNotFoundError:
            print("matplotlib unavailable; skipping plot")
        else:
            # cumulative histogram of the samples against the ideal flat-prior
            # CDF (dashed diagonal from (0, 0) to (h0ul, 1))
            fig, ax = pl.subplots(1, 1)
            ax.hist(
                h0samps,
                bins=20,
                density=True,
                cumulative=True,
                histtype="stepfilled",
                alpha=0.2,
            )
            ax.plot([0.0, h0ul], [0.0, 1], "k--")
            ax.set_xlim((0.0, h0ul))
            ax.set_ylim((0.0, 1.0))
            ax.set_xlabel("h_0")
            ax.set_ylabel("Cumulative probability")
            fig.savefig("h0samps.png")
            print("Saved plot to 'h0samps.png'")
        break

    # clean up per-run temporary files. This is skipped when the loop breaks
    # above, so those files stay on disk. outfile_SNR / outfile_Znoise are
    # presumably written by the executable alongside outfile; os.remove raises
    # if one is missing.
    for fs in (outfile, outfile_SNR, outfile_Znoise):
        os.remove(fs)
145
# Delete the input files shared by every run (the prior file left over is the
# one written by the last iteration), then exit with the accumulated status.
shared_inputs = (priorf, parf, datafile)
for path in shared_inputs:
    os.remove(path)

sys.exit(exit_code)