Coverage for parallel_bilby/analysis/main.py: 89%
1"""
2Module to run parallel bilby using MPI
3"""
4import datetime
5import json
6import os
7import pickle
8import time
10import bilby
11import numpy as np
12import pandas as pd
13from bilby.core.utils import logger
14from bilby.gw import conversion
15from nestcheck import data_processing
16from pandas import DataFrame
18from ..parser import create_analysis_parser, parse_analysis_args
19from ..schwimmbad_fast import MPIPoolFast as MPIPool
20from ..utils import get_cli_args, stdout_sampling_log
21from .analysis_run import AnalysisRun
22from .plotting import plot_current_state
23from .read_write import (
24 format_result,
25 read_saved_state,
26 write_current_state,
27 write_sample_dump,
28)
29from .sample_space import fill_sample
32def analysis_runner(
33 data_dump,
34 outdir=None,
35 label=None,
36 dynesty_sample="acceptance-walk",
37 nlive=5,
38 dynesty_bound="live",
39 walks=100,
40 proposals=None,
41 maxmcmc=5000,
42 naccept=60,
43 nact=2,
44 facc=0.5,
45 min_eff=10,
46 enlarge=1.5,
47 sampling_seed=0,
48 bilby_zero_likelihood_mode=False,
49 rejection_sample_posterior=True,
50 #
51 fast_mpi=False,
52 mpi_timing=False,
53 mpi_timing_interval=0,
54 check_point_deltaT=3600,
55 n_effective=np.inf,
56 dlogz=10,
57 do_not_save_bounds_in_resume=True,
58 n_check_point=1000,
59 max_its=1e10,
60 max_run_time=1e10,
61 rotate_checkpoints=False,
62 no_plot=False,
63 nestcheck=False,
64 result_format="hdf5",
65 **kwargs,
66):
67 """
68 API for running the analysis from Python instead of the command line.
69 It takes all the same options as the CLI, specified as keyword arguments.
71 Returns
72 -------
73 exit_reason: integer u
74 Used during testing, to determine the reason the code halted:
75 0 = run completed normally, based on convergence criteria
76 1 = reached max iterations
77 2 = reached max runtime
78 MPI worker tasks always return -1
80 """

    # Initialise a run
    run = AnalysisRun(
        data_dump=data_dump,
        outdir=outdir,
        label=label,
        dynesty_sample=dynesty_sample,
        nlive=nlive,
        dynesty_bound=dynesty_bound,
        walks=walks,
        maxmcmc=maxmcmc,
        nact=nact,
        naccept=naccept,
        facc=facc,
        min_eff=min_eff,
        enlarge=enlarge,
        sampling_seed=sampling_seed,
        proposals=proposals,
        bilby_zero_likelihood_mode=bilby_zero_likelihood_mode,
    )
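
    # Wall-clock bookkeeping: t0 marks the start of the current sampling
    # stretch; sampling_time accumulates across checkpointed restarts.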
    t0 = datetime.datetime.now()
    sampling_time = 0
    with MPIPool(
        parallel_comms=fast_mpi,
        time_mpi=mpi_timing,
        timing_interval=mpi_timing_interval,
        use_dill=True,
    ) as pool:
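        # Only the master rank drives the sampler; worker ranks stay inside
        # the pool serving map() calls until the context exits.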
        if pool.is_master():
            POOL_SIZE = pool.size

            logger.info(f"sampling_keys={run.sampling_keys}")
            if run.periodic:
                logger.info(
                    f"Periodic keys: {[run.sampling_keys[ii] for ii in run.periodic]}"
                )
            if run.reflective:
                logger.info(
                    f"Reflective keys: {[run.sampling_keys[ii] for ii in run.reflective]}"
                )
            logger.info("Using priors:")
            for key in run.priors:
                logger.info(f"{key}: {run.priors[key]}")

            resume_file = f"{run.outdir}/{run.label}_checkpoint_resume.pickle"
            samples_file = f"{run.outdir}/{run.label}_samples.dat"
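
            # Try to resume from a previous checkpoint; read_saved_state
            # returns sampler=False when no usable resume file is found.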
            sampler, sampling_time = read_saved_state(resume_file)

            if sampler is False:
                logger.info(f"Initializing sampling points with pool size={POOL_SIZE}")
                live_points = run.get_initial_points_from_prior(pool)
                logger.info(
                    f"Initialize NestedSampler with "
                    f"{json.dumps(run.init_sampler_kwargs, indent=1, sort_keys=True)}"
                )
                sampler = run.get_nested_sampler(live_points, pool, POOL_SIZE)
            else:
                # Reinstate the pool and map (not saved in the pickle)
                logger.info(f"Read in resume file with sampling_time = {sampling_time}")
                sampler.pool = pool
                sampler.M = pool.map
                sampler.loglikelihood.pool = pool

            logger.info(
                f"Starting sampling for job {run.label}, with pool size={POOL_SIZE} "
                f"and check_point_deltaT={check_point_deltaT}"
            )
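
            # Dynesty stopping criteria: stop once the remaining evidence
            # drops below dlogz or n_effective posterior samples are reached,
            # whichever comes first.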
            sampler_kwargs = dict(
                n_effective=n_effective,
                dlogz=dlogz,
                save_bounds=not do_not_save_bounds_in_resume,
            )
            logger.info(f"Run criteria: {json.dumps(sampler_kwargs)}")

            run_time = 0
            early_stop = False

            for it, res in enumerate(sampler.sample(**sampler_kwargs)):
                stdout_sampling_log(
                    results=res, niter=it, ncall=sampler.ncall, dlogz=dlogz
                )

                iteration_time = (datetime.datetime.now() - t0).total_seconds()
                t0 = datetime.datetime.now()

                sampling_time += iteration_time
                run_time += iteration_time

                if os.path.isfile(resume_file):
                    last_checkpoint_s = time.time() - os.path.getmtime(resume_file)
                else:
                    last_checkpoint_s = np.inf

                """
                Criteria for writing checkpoints:
                a) time since last checkpoint > check_point_deltaT
                b) reached an integer multiple of n_check_point
                c) reached max iterations
                d) reached max runtime
                """

                if (
                    last_checkpoint_s > check_point_deltaT
                    or (it % n_check_point == 0 and it != 0)
                    or it == max_its
                    or run_time > max_run_time
                ):
                    write_current_state(
                        sampler,
                        resume_file,
                        sampling_time,
                        rotate_checkpoints,
                    )
                    write_sample_dump(sampler, samples_file, run.sampling_keys)
                    if no_plot is False:
                        plot_current_state(
                            sampler, run.sampling_keys, run.outdir, run.label
                        )

                    if it == max_its:
                        exit_reason = 1
                        logger.info(
                            f"Max iterations ({it}) reached; stopping sampling (exit_reason={exit_reason})."
                        )
                        early_stop = True
                        break

                    if run_time > max_run_time:
                        exit_reason = 2
                        logger.info(
                            f"Max run time ({max_run_time}) reached; stopping sampling (exit_reason={exit_reason})."
                        )
                        early_stop = True
                        break

            if not early_stop:
                exit_reason = 0
                # Adding the final set of live points.
                for it_final, res in enumerate(sampler.add_live_points()):
                    pass

                # Create a final checkpoint and set of plots
                write_current_state(
                    sampler, resume_file, sampling_time, rotate_checkpoints
                )
                write_sample_dump(sampler, samples_file, run.sampling_keys)
                if no_plot is False:
                    plot_current_state(
                        sampler, run.sampling_keys, run.outdir, run.label
                    )

            sampling_time += (datetime.datetime.now() - t0).total_seconds()

            out = sampler.results

            if nestcheck is True:
                logger.info("Creating nestcheck files")
                ns_run = data_processing.process_dynesty_run(out)
                nestcheck_path = os.path.join(run.outdir, "Nestcheck")
                bilby.core.utils.check_directory_exists_and_if_not_mkdir(
                    nestcheck_path
                )
                nestcheck_result = f"{nestcheck_path}/{run.label}_nestcheck.pickle"

                with open(nestcheck_result, "wb") as file_nest:
                    pickle.dump(ns_run, file_nest)
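
            # Importance weights for the nested samples: normalise the
            # per-sample log-weights by the final log-evidence estimate.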
            weights = np.exp(out["logwt"] - out["logz"][-1])
            nested_samples = DataFrame(out.samples, columns=run.sampling_keys)
            nested_samples["weights"] = weights
            nested_samples["log_likelihood"] = out.logl

            result = format_result(
                run,
                data_dump,
                out,
                weights,
                nested_samples,
                sampler_kwargs,
                sampling_time,
            )

            posterior = conversion.fill_from_fixed_priors(
                result.posterior, run.priors
            )

            logger.info(
                "Generating posterior from marginalized parameters for"
                f" nsamples={len(posterior)}, POOL={pool.size}"
            )
            fill_args = [
                (ii, row, run.likelihood) for ii, row in posterior.iterrows()
            ]
            samples = pool.map(fill_sample, fill_args)
            result.posterior = pd.DataFrame(samples)

            logger.debug(
                "Updating prior to the actual prior (undoing marginalization)"
            )
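            # Restore the full prior for any parameter that was analytically
            # marginalized over during sampling.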
            for par, name in zip(
                ["distance", "phase", "time"],
                ["luminosity_distance", "phase", "geocent_time"],
            ):
                if getattr(run.likelihood, f"{par}_marginalization", False):
                    run.priors[name] = run.likelihood.priors[name]
            result.priors = run.priors
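
            # fill_sample can return single-element lists; unwrap them and
            # keep only numeric columns before saving the result.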
            result.posterior = result.posterior.applymap(
                lambda x: x[0] if isinstance(x, list) else x
            )
            result.posterior = result.posterior.select_dtypes([np.number])
            logger.info(
                f"Saving result to {run.outdir}/{run.label}_result.{result_format}"
            )
            if result_format != "json":  # json is saved by default
                result.save_to_file(extension="json")
            result.save_to_file(extension=result_format)
            print(
                f"Sampling time = {datetime.timedelta(seconds=result.sampling_time)}"
            )
            print(f"Number of lnl calls = {result.num_likelihood_evaluations}")
            print(result)
            if no_plot is False:
                result.plot_corner()

        else:
            exit_reason = -1
    return exit_reason


def main():
    """
    parallel_bilby_analysis entrypoint.

    This function is a wrapper around analysis_runner(),
    giving it a command line interface.
    """
    cli_args = get_cli_args()

    # Parse command line arguments
    analysis_parser = create_analysis_parser(sampler="dynesty")
    input_args = parse_analysis_args(analysis_parser, cli_args=cli_args)

    # Run the analysis
    analysis_runner(**vars(input_args))
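

# Example: calling the runner from Python instead of the CLI. A minimal
# sketch; the data-dump path and sampler settings below are hypothetical,
# and the pickle must first be produced by the generation step.
#
#     from parallel_bilby.analysis.main import analysis_runner
#
#     exit_reason = analysis_runner(
#         data_dump="outdir/data/label_data_dump.pickle",  # hypothetical path
#         outdir="outdir/result",
#         label="label",
#         nlive=1000,
#     )
#
# Launch under MPI (e.g. `mpirun -n 16 python run.py`) so that MPIPool can
# farm likelihood evaluations out to worker ranks.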