Coverage for parallel_bilby/analysis/main.py: 89%

126 statements  

1""" 

2Module to run parallel bilby using MPI 

3""" 

4import datetime 

5import json 

6import os 

7import pickle 

8import time 

9 

10import bilby 

11import numpy as np 

12import pandas as pd 

13from bilby.core.utils import logger 

14from bilby.gw import conversion 

15from nestcheck import data_processing 

16from pandas import DataFrame 

17 

18from ..parser import create_analysis_parser, parse_analysis_args 

19from ..schwimmbad_fast import MPIPoolFast as MPIPool 

20from ..utils import get_cli_args, stdout_sampling_log 

21from .analysis_run import AnalysisRun 

22from .plotting import plot_current_state 

23from .read_write import ( 

24 format_result, 

25 read_saved_state, 

26 write_current_state, 

27 write_sample_dump, 

28) 

29from .sample_space import fill_sample 

30 

31 

32def analysis_runner( 

33 data_dump, 

34 outdir=None, 

35 label=None, 

36 dynesty_sample="acceptance-walk", 

37 nlive=5, 

38 dynesty_bound="live", 

39 walks=100, 

40 proposals=None, 

41 maxmcmc=5000, 

42 naccept=60, 

43 nact=2, 

44 facc=0.5, 

45 min_eff=10, 

46 enlarge=1.5, 

47 sampling_seed=0, 

48 bilby_zero_likelihood_mode=False, 

49 rejection_sample_posterior=True, 

50 # 

51 fast_mpi=False, 

52 mpi_timing=False, 

53 mpi_timing_interval=0, 

54 check_point_deltaT=3600, 

55 n_effective=np.inf, 

56 dlogz=10, 

57 do_not_save_bounds_in_resume=True, 

58 n_check_point=1000, 

59 max_its=1e10, 

60 max_run_time=1e10, 

61 rotate_checkpoints=False, 

62 no_plot=False, 

63 nestcheck=False, 

64 result_format="hdf5", 

65 **kwargs, 

66): 

67 """ 

68 API for running the analysis from Python instead of the command line. 

69 It takes all the same options as the CLI, specified as keyword arguments. 

70 

71 Returns 

72 ------- 

73 exit_reason: integer u 

74 Used during testing, to determine the reason the code halted: 

75 0 = run completed normally, based on convergence criteria 

76 1 = reached max iterations 

77 2 = reached max runtime 

78 MPI worker tasks always return -1 

79 

80 """ 

81 

    # Initialise a run
    run = AnalysisRun(
        data_dump=data_dump,
        outdir=outdir,
        label=label,
        dynesty_sample=dynesty_sample,
        nlive=nlive,
        dynesty_bound=dynesty_bound,
        walks=walks,
        maxmcmc=maxmcmc,
        nact=nact,
        naccept=naccept,
        facc=facc,
        min_eff=min_eff,
        enlarge=enlarge,
        sampling_seed=sampling_seed,
        proposals=proposals,
        bilby_zero_likelihood_mode=bilby_zero_likelihood_mode,
    )

    t0 = datetime.datetime.now()
    sampling_time = 0
    with MPIPool(
        parallel_comms=fast_mpi,
        time_mpi=mpi_timing,
        timing_interval=mpi_timing_interval,
        use_dill=True,
    ) as pool:
        if pool.is_master():
            POOL_SIZE = pool.size

            logger.info(f"sampling_keys={run.sampling_keys}")
            if run.periodic:
                logger.info(
                    f"Periodic keys: {[run.sampling_keys[ii] for ii in run.periodic]}"
                )
            if run.reflective:
                logger.info(
                    f"Reflective keys: {[run.sampling_keys[ii] for ii in run.reflective]}"
                )
            logger.info("Using priors:")
            for key in run.priors:
                logger.info(f"{key}: {run.priors[key]}")

            resume_file = f"{run.outdir}/{run.label}_checkpoint_resume.pickle"
            samples_file = f"{run.outdir}/{run.label}_samples.dat"

            sampler, sampling_time = read_saved_state(resume_file)

            if sampler is False:
                logger.info(f"Initializing sampling points with pool size={POOL_SIZE}")
                live_points = run.get_initial_points_from_prior(pool)
                logger.info(
                    f"Initialize NestedSampler with "
                    f"{json.dumps(run.init_sampler_kwargs, indent=1, sort_keys=True)}"
                )
                sampler = run.get_nested_sampler(live_points, pool, POOL_SIZE)
            else:
                # Reinstate the pool and map (not saved in the pickle)
                logger.info(f"Read in resume file with sampling_time = {sampling_time}")
                sampler.pool = pool
                sampler.M = pool.map
                sampler.loglikelihood.pool = pool

            logger.info(
                f"Starting sampling for job {run.label}, with pool size={POOL_SIZE} "
                f"and check_point_deltaT={check_point_deltaT}"
            )

            sampler_kwargs = dict(
                n_effective=n_effective,
                dlogz=dlogz,
                save_bounds=not do_not_save_bounds_in_resume,
            )
            logger.info(f"Run criteria: {json.dumps(sampler_kwargs)}")

            run_time = 0
            early_stop = False

            for it, res in enumerate(sampler.sample(**sampler_kwargs)):
                stdout_sampling_log(
                    results=res, niter=it, ncall=sampler.ncall, dlogz=dlogz
                )

                iteration_time = (datetime.datetime.now() - t0).total_seconds()
                t0 = datetime.datetime.now()

                sampling_time += iteration_time
                run_time += iteration_time

                if os.path.isfile(resume_file):
                    last_checkpoint_s = time.time() - os.path.getmtime(resume_file)
                else:
                    last_checkpoint_s = np.inf

                """
                Criteria for writing checkpoints:
                a) time since last checkpoint > check_point_deltaT
                b) reached an integer multiple of n_check_point
                c) reached max iterations
                d) reached max runtime
                """

                if (
                    last_checkpoint_s > check_point_deltaT
                    or (it % n_check_point == 0 and it != 0)
                    or it == max_its
                    or run_time > max_run_time
                ):
                    write_current_state(
                        sampler,
                        resume_file,
                        sampling_time,
                        rotate_checkpoints,
                    )
                    write_sample_dump(sampler, samples_file, run.sampling_keys)
                    if no_plot is False:
                        plot_current_state(
                            sampler, run.sampling_keys, run.outdir, run.label
                        )

                    if it == max_its:
                        exit_reason = 1
                        logger.info(
                            f"Max iterations ({it}) reached; stopping sampling (exit_reason={exit_reason})."
                        )
                        early_stop = True
                        break

                    if run_time > max_run_time:
                        exit_reason = 2
                        logger.info(
                            f"Max run time ({max_run_time}) reached; stopping sampling (exit_reason={exit_reason})."
                        )
                        early_stop = True
                        break

            if not early_stop:
                exit_reason = 0
                # Adding the final set of live points.
                for it_final, res in enumerate(sampler.add_live_points()):
                    pass

                # Create a final checkpoint and set of plots
                write_current_state(
                    sampler, resume_file, sampling_time, rotate_checkpoints
                )
                write_sample_dump(sampler, samples_file, run.sampling_keys)
                if no_plot is False:
                    plot_current_state(
                        sampler, run.sampling_keys, run.outdir, run.label
                    )

                sampling_time += (datetime.datetime.now() - t0).total_seconds()

                out = sampler.results

                if nestcheck is True:
                    logger.info("Creating nestcheck files")
                    ns_run = data_processing.process_dynesty_run(out)
                    nestcheck_path = os.path.join(run.outdir, "Nestcheck")
                    bilby.core.utils.check_directory_exists_and_if_not_mkdir(
                        nestcheck_path
                    )
                    nestcheck_result = f"{nestcheck_path}/{run.label}_nestcheck.pickle"

                    with open(nestcheck_result, "wb") as file_nest:
                        pickle.dump(ns_run, file_nest)

                weights = np.exp(out["logwt"] - out["logz"][-1])
                nested_samples = DataFrame(out.samples, columns=run.sampling_keys)
                nested_samples["weights"] = weights
                nested_samples["log_likelihood"] = out.logl

                result = format_result(
                    run,
                    data_dump,
                    out,
                    weights,
                    nested_samples,
                    sampler_kwargs,
                    sampling_time,
                )

                posterior = conversion.fill_from_fixed_priors(
                    result.posterior, run.priors
                )

                logger.info(
                    "Generating posterior from marginalized parameters for"
                    f" nsamples={len(posterior)}, POOL={pool.size}"
                )
                fill_args = [
                    (ii, row, run.likelihood) for ii, row in posterior.iterrows()
                ]
                samples = pool.map(fill_sample, fill_args)
                result.posterior = pd.DataFrame(samples)

                logger.debug(
                    "Updating prior to the actual prior (undoing marginalization)"
                )
                for par, name in zip(
                    ["distance", "phase", "time"],
                    ["luminosity_distance", "phase", "geocent_time"],
                ):
                    if getattr(run.likelihood, f"{par}_marginalization", False):
                        run.priors[name] = run.likelihood.priors[name]
                result.priors = run.priors

                result.posterior = result.posterior.applymap(
                    lambda x: x[0] if isinstance(x, list) else x
                )
                result.posterior = result.posterior.select_dtypes([np.number])
                logger.info(
                    f"Saving result to {run.outdir}/{run.label}_result.{result_format}"
                )
                if result_format != "json":  # json is saved by default
                    result.save_to_file(extension="json")
                result.save_to_file(extension=result_format)
                print(
                    f"Sampling time = {datetime.timedelta(seconds=result.sampling_time)}s"
                )
                print(f"Number of lnl calls = {result.num_likelihood_evaluations}")
                print(result)
                if no_plot is False:
                    result.plot_corner()

        else:
            exit_reason = -1
        return exit_reason


def main():
316 """ 

317 paralell_bilby_analysis entrypoint. 

318 

319 This function is a wrapper around analysis_runner(), 

320 giving it a command line interface. 

321 """ 

    cli_args = get_cli_args()

    # Parse command line arguments
    analysis_parser = create_analysis_parser(sampler="dynesty")
    input_args = parse_analysis_args(analysis_parser, cli_args=cli_args)

    # Run the analysis
    analysis_runner(**vars(input_args))
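
Usage illustration (not part of main.py): a minimal sketch of driving the sampler through the Python API described in the analysis_runner docstring above. The data-dump path, output directory, and label are hypothetical placeholders; in practice they come from the parallel_bilby generation step, and production runs normally reach the same code via the parallel_bilby_analysis entry point under MPI.

# Minimal sketch, assuming a data dump pickle from the generation step
# already exists on disk (all paths and labels below are hypothetical).
from parallel_bilby.analysis.main import analysis_runner

exit_reason = analysis_runner(
    data_dump="outdir/data/example_data_dump.pickle",  # hypothetical path
    outdir="outdir/result",
    label="example",
    nlive=1000,
    max_run_time=4 * 3600,  # seconds: checkpoint and stop after ~4 hours
)
# exit_reason: 0 = converged, 1 = hit max_its, 2 = hit max_run_time;
# MPI worker tasks return -1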