Coverage for bilby/core/sampler/dynesty.py: 73%

524 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import datetime 

2import inspect 

3import os 

4import sys 

5import time 

6import warnings 

7 

8import numpy as np 

9from pandas import DataFrame 

10 

11from ..result import rejection_sample 

12from ..utils import ( 

13 check_directory_exists_and_if_not_mkdir, 

14 latex_plot_format, 

15 logger, 

16 safe_file_dump, 

17) 

18from .base_sampler import NestedSampler, Sampler, _SamplingContainer, signal_wrapper 

19 

20 

def _set_sampling_kwargs(args):
    """
    Store the MCMC sampling settings on :code:`_SamplingContainer`.

    The settings arrive packed in a single argument so that the function can
    be handed to :code:`pool.map`, which passes one item per call.

    Parameters
    ==========
    args: tuple
        The four values :code:`(nact, maxmcmc, proposals, naccept)`, in that
        order.
    """
    nact, maxmcmc, proposals, naccept = args
    settings = (
        ("nact", nact),
        ("maxmcmc", maxmcmc),
        ("proposals", proposals),
        ("naccept", naccept),
    )
    for attribute, value in settings:
        setattr(_SamplingContainer, attribute, value)

27 

28 

def _prior_transform_wrapper(theta):
    """
    Wrapper to the prior transformation. Needed for multiprocessing.

    Parameters
    ==========
    theta:
        The point to rescale, one entry per search parameter.

    Returns
    =======
    The output of :code:`priors.rescale` for the search parameter keys held
    in :code:`_sampling_convenience_dump`.
    """
    from .base_sampler import _sampling_convenience_dump as dump

    rescale = dump.priors.rescale
    return rescale(dump.search_parameter_keys, theta)

36 

37 

def _log_likelihood_wrapper(theta):
    """
    Wrapper to the log likelihood. Needed for multiprocessing.

    Parameters
    ==========
    theta:
        The parameter values, ordered like the search parameter keys held in
        :code:`_sampling_convenience_dump`.

    Returns
    =======
    The log likelihood (or the log likelihood ratio when :code:`use_ratio`
    is set) if the prior constraints are satisfied, otherwise
    :code:`np.nan_to_num(-np.inf)`, i.e. a very large negative finite number.
    """
    from .base_sampler import _sampling_convenience_dump as dump

    keys = dump.search_parameter_keys
    # A dedicated dictionary is handed to the constraint check so that the
    # parameters given to the likelihood below are built independently.
    constraint_sample = {key: theta[index] for index, key in enumerate(keys)}
    if not dump.priors.evaluate_constraints(constraint_sample):
        return np.nan_to_num(-np.inf)

    likelihood = dump.likelihood
    likelihood.parameters.update(dict(zip(dump.search_parameter_keys, theta)))
    if dump.use_ratio:
        return likelihood.log_likelihood_ratio()
    return likelihood.log_likelihood()

59 

60 

61class Dynesty(NestedSampler): 

62 """ 

63 bilby wrapper of `dynesty.NestedSampler` 

64 (https://dynesty.readthedocs.io/en/latest/) 

65 

66 All positional and keyword arguments (i.e., the args and kwargs) passed to 

67 `run_sampler` will be propagated to `dynesty.NestedSampler`, see 

68 documentation for that class for further help. Under Other Parameters below, 

69 we list commonly used kwargs and the Bilby defaults. 

70 

71 Parameters 

72 ========== 

73 likelihood: likelihood.Likelihood 

74 An object with a log_likelihood method

75 priors: bilby.core.prior.PriorDict, dict 

76 Priors to be used in the search. 

77 This has attributes for each parameter to be sampled. 

78 outdir: str, optional 

79 Name of the output directory 

80 label: str, optional 

81 Naming scheme of the output files 

82 use_ratio: bool, optional 

83 Switch to set whether or not you want to use the log-likelihood ratio 

84 or just the log-likelihood 

85 plot: bool, optional 

86 Switch to set whether or not you want to create traceplots 

87 skip_import_verification: bool 

88 Skips the check if the sampler is installed if true. This is 

89 only advisable for testing environments 

90 print_method: str ('tqdm') 

91 The method to use for printing. The options are: 

92 - 'tqdm': use a `tqdm` `pbar`, this is the default. 

93 - 'interval-$TIME': print to `stdout` every `$TIME` seconds, 

94 e.g., 'interval-10' prints every ten seconds, this does not print every iteration 

95 - else: print to `stdout` at every iteration 

96 exit_code: int 

97 The code which the sampler exits on if it hasn't finished sampling

98 check_point: bool, 

99 If true, use check pointing. 

100 check_point_plot: bool, 

101 If true, generate a trace plot along with the check-point 

102 check_point_delta_t: float (600) 

103 The minimum checkpoint period (in seconds). Should the run be 

104 interrupted, it can be resumed from the last checkpoint. 

105 n_check_point: int, optional (None) 

106 The number of steps to take before checking whether to check_point. 

107 resume: bool 

108 If true, resume run from checkpoint (if available) 

109 maxmcmc: int (5000) 

110 The maximum length of the MCMC exploration to find a new point 

111 nact: int (2) 

112 The number of autocorrelation lengths for MCMC exploration. 

113 For use with the :code:`act-walk` and :code:`rwalk` sample methods. 

114 See the dynesty guide in the Bilby docs for more details. 

115 naccept: int (60) 

116 The expected number of accepted steps for MCMC exploration when using 

117 the :code:`acceptance-walk` sampling method. 

118 rejection_sample_posterior: bool (True) 

119 Whether to form the posterior by rejection sampling the nested samples. 

120 If False, the nested samples are resampled with repetition. This was 

121 the default behaviour in :code:`Bilby<=1.4.1` and leads to 

122 non-independent samples being produced. 

123 proposals: iterable (None) 

124 The proposal methods to use during MCMC. This can be some combination 

125 of :code:`"diff", "volumetric"`. See the dynesty guide in the Bilby docs 

126 for more details. default=:code:`["diff"]`. 

127 rstate: numpy.random.Generator (None) 

128 Instance of a numpy random generator for generating random numbers. 

129 Also see :code:`seed` in 'Other Parameters'. 

130 

131 Other Parameters 

132 ================ 

133 nlive: int, (1000) 

134 The number of live points, note this can also equivalently be given as 

135 one of [nlive, nlives, n_live_points, npoints] 

136 bound: {'live', 'live-multi', 'none', 'single', 'multi', 'balls', 'cubes'}, ('live') 

137 Method used to select new points 

138 sample: {'act-walk', 'acceptance-walk', 'unif', 'rwalk', 'slice', 

139 'rslice', 'hslice', 'rwalk_dynesty'}, ('act-walk') 

140 Method used to sample uniformly within the likelihood constraints, 

141 conditioned on the provided bounds 

142 walks: int (100) 

143 Number of walks taken if using the dynesty implemented sample methods 

144 Note that the default `walks` in dynesty itself is 25, although using 

145 `ndim * 10` can be a reasonable rule of thumb for new problems. 

146 For :code:`sample="act-walk"` and :code:`sample="rwalk"` this parameter 

147 has no impact on the sampling. 

148 dlogz: float, (0.1) 

149 Stopping criteria 

150 seed: int (None) 

151 Use to seed the random number generator if :code:`rstate` is not 

152 specified. 

153 """ 

154 

155 sampler_name = "dynesty" 

156 sampling_seed_key = "seed" 

157 

158 @property 

159 def _dynesty_init_kwargs(self): 

160 params = inspect.signature(self.sampler_init).parameters 

161 kwargs = { 

162 key: param.default 

163 for key, param in params.items() 

164 if param.default != param.empty 

165 } 

166 kwargs["sample"] = "act-walk" 

167 kwargs["bound"] = "live" 

168 kwargs["update_interval"] = 600 

169 kwargs["facc"] = 0.2 

170 return kwargs 

171 

172 @property 

173 def _dynesty_sampler_kwargs(self): 

174 params = inspect.signature(self.sampler_class.run_nested).parameters 

175 kwargs = { 

176 key: param.default 

177 for key, param in params.items() 

178 if param.default != param.empty 

179 } 

180 kwargs["save_bounds"] = False 

181 if "dlogz" in kwargs: 

182 kwargs["dlogz"] = 0.1 

183 return kwargs 

184 

185 @property 

186 def default_kwargs(self): 

187 kwargs = self._dynesty_init_kwargs 

188 kwargs.update(self._dynesty_sampler_kwargs) 

189 kwargs["seed"] = None 

190 return kwargs 

191 

def __init__(
    self,
    likelihood,
    priors,
    outdir="outdir",
    label="label",
    use_ratio=False,
    plot=False,
    skip_import_verification=False,
    check_point=True,
    check_point_plot=True,
    n_check_point=None,
    check_point_delta_t=600,
    resume=True,
    nestcheck=False,
    exit_code=130,
    print_method="tqdm",
    maxmcmc=5000,
    nact=2,
    naccept=60,
    rejection_sample_posterior=True,
    proposals=None,
    **kwargs,
):
    """
    Set up the sampler; see the class documentation for the arguments.

    The MCMC and printing settings are stored before the parent initialiser
    runs, and the keyword aliases are translated before the remaining
    :code:`kwargs` are forwarded to it.
    """
    # These attributes are assigned ahead of the parent initialiser.
    self.print_method = print_method
    self.proposals = proposals
    self.maxmcmc = maxmcmc
    self.naccept = naccept
    self.nact = nact

    self._translate_kwargs(kwargs)
    super(Dynesty, self).__init__(
        likelihood=likelihood,
        priors=priors,
        outdir=outdir,
        label=label,
        use_ratio=use_ratio,
        plot=plot,
        skip_import_verification=skip_import_verification,
        exit_code=exit_code,
        **kwargs,
    )

    self.n_check_point = n_check_point
    self.check_point = check_point
    self.check_point_plot = check_point_plot
    self.resume = resume
    self.rejection_sample_posterior = rejection_sample_posterior
    for boundary in ("periodic", "reflective"):
        self._apply_dynesty_boundaries(boundary)

    self.nestcheck = nestcheck

    if self.n_check_point is None:
        if np.isnan(self._log_likelihood_eval_time):
            # No timing information is available: fall back to the minimum.
            self.n_check_point = 10
        else:
            # Roughly ten checks per checkpoint period, but never fewer
            # than ten calls between checks.
            calls_per_check = int(
                check_point_delta_t / self._log_likelihood_eval_time / 10
            )
            self.n_check_point = max(calls_per_check, 10)
    self.check_point_delta_t = check_point_delta_t
    logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s")

    self.resume_file = f"{self.outdir}/{self.label}_resume.pickle"
    self.sampling_time = datetime.timedelta()
    self.pbar = None

257 

258 @property 

259 def sampler_function_kwargs(self): 

260 return {key: self.kwargs[key] for key in self._dynesty_sampler_kwargs} 

261 

262 @property 

263 def sampler_init_kwargs(self): 

264 return {key: self.kwargs[key] for key in self._dynesty_init_kwargs} 

265 

def _translate_kwargs(self, kwargs):
    """
    Rename equivalent keyword arguments to the names used by dynesty.

    The dictionary returned by the parent implementation is updated in
    place. An alias is only moved when its canonical name is not already
    present. If :code:`seed` is given without :code:`rstate`, a numpy random
    generator seeded with it is stored under :code:`rstate`; if both are
    given, :code:`seed` is ignored with a warning.

    Parameters
    ==========
    kwargs: dict
        The keyword arguments to translate.
    """
    kwargs = super()._translate_kwargs(kwargs)

    alias_table = (
        ("nlive", self.npoints_equiv_kwargs),
        ("print_progress", ("verbose",)),
        ("walks", self.walks_equiv_kwargs),
        ("queue_size", self.npool_equiv_kwargs),
    )
    for canonical, aliases in alias_table:
        if canonical in kwargs:
            continue
        for alias in aliases:
            if alias in kwargs:
                kwargs[canonical] = kwargs.pop(alias)

    if "seed" in kwargs:
        if "rstate" in kwargs:
            logger.warning(
                "Kwargs contain both 'rstate' and 'seed', ignoring 'seed'."
            )
        else:
            kwargs["rstate"] = np.random.default_rng(kwargs.get("seed"))

291 

def _verify_kwargs_against_default_kwargs(self):
    """
    Fill in Bilby specific defaults, then run the base class verification.

    A falsy :code:`walks` is replaced by 100. When no :code:`print_func` is
    given, :code:`self._print_func` is used and, for an "interval" print
    method, the state needed to throttle the printing is initialised.
    """
    settings = self.kwargs
    if not settings["walks"]:
        settings["walks"] = 100
    if settings["print_func"] is None:
        settings["print_func"] = self._print_func
        if "interval" in self.print_method:
            # The print method has the form "interval-$TIME", $TIME in seconds.
            period = float(self.print_method.split("-")[1])
            self._last_print_time = datetime.datetime.now()
            self._print_interval = datetime.timedelta(seconds=period)
    Sampler._verify_kwargs_against_default_kwargs(self)

303 

304 @classmethod 

305 def get_expected_outputs(cls, outdir=None, label=None): 

306 """Get lists of the expected outputs directories and files. 

307 

308 These are used by :code:`bilby_pipe` when transferring files via HTCondor. 

309 

310 Parameters 

311 ---------- 

312 outdir : str 

313 The output directory. 

314 label : str 

315 The label for the run. 

316 

317 Returns 

318 ------- 

319 list 

320 List of file names. 

321 list 

322 List of directory names. Will always be empty for dynesty. 

323 """ 

324 filenames = [] 

325 for kind in ["resume", "dynesty"]: 

326 filename = os.path.join(outdir, f"{label}_{kind}.pickle") 

327 filenames.append(filename) 

328 return filenames, [] 

329 

330 def _print_func( 

331 self, 

332 results, 

333 niter, 

334 ncall=None, 

335 dlogz=None, 

336 stop_val=None, 

337 nbatch=None, 

338 logl_min=-np.inf, 

339 logl_max=np.inf, 

340 *args, 

341 **kwargs, 

342 ): 

343 """Replacing status update for dynesty.result.print_func""" 

344 if "interval" in self.print_method: 

345 _time = datetime.datetime.now() 

346 if _time - self._last_print_time < self._print_interval: 

347 return 

348 else: 

349 self._last_print_time = _time 

350 

351 # Add time in current run to overall sampling time 

352 total_time = self.sampling_time + _time - self.start_time 

353 

354 # Remove fractional seconds 

355 total_time_str = str(total_time).split(".")[0] 

356 

357 # Extract results at the current iteration. 

358 loglstar = results.loglstar 

359 delta_logz = results.delta_logz 

360 logz = results.logz 

361 logzvar = results.logzvar 

362 nc = results.nc 

363 bounditer = results.bounditer 

364 eff = results.eff 

365 

366 # Adjusting outputs for printing. 

367 if delta_logz > 1e6: 

368 delta_logz = np.inf 

369 if 0.0 <= logzvar <= 1e6: 

370 logzerr = np.sqrt(logzvar) 

371 else: 

372 logzerr = np.nan 

373 if logz <= -1e6: 

374 logz = -np.inf 

375 if loglstar <= -1e6: 

376 loglstar = -np.inf 

377 

378 if self.use_ratio: 

379 key = "logz-ratio" 

380 else: 

381 key = "logz" 

382 

383 # Constructing output. 

384 string = list() 

385 string.append(f"bound:{bounditer:d}") 

386 string.append(f"nc:{nc:3d}") 

387 string.append(f"ncall:{ncall:.1e}") 

388 string.append(f"eff:{eff:0.1f}%") 

389 string.append(f"{key}={logz:0.2f}+/-{logzerr:0.2f}") 

390 if nbatch is not None: 

391 string.append(f"batch:{nbatch}") 

392 if logl_min > -np.inf: 

393 string.append(f"logl:{logl_min:.1f} < {loglstar:.1f} < {logl_max:.1f}") 

394 if dlogz is not None: 

395 string.append(f"dlogz:{delta_logz:0.3f}>{dlogz:0.2g}") 

396 else: 

397 string.append(f"stop:{stop_val:6.3f}") 

398 string = " ".join(string) 

399 

400 if self.print_method == "tqdm": 

401 self.pbar.set_postfix_str(string, refresh=False) 

402 self.pbar.update(niter - self.pbar.n) 

403 else: 

404 print(f"{niter}it [{total_time_str} {string}]", file=sys.stdout, flush=True) 

405 

def _apply_dynesty_boundaries(self, key):
    """
    Store the indices of the parameters whose prior boundary equals ``key``.

    The periodic kwargs passed into dynesty allows the parameters to wander
    out of the bounds, this includes both periodic and reflective. These
    are then handled in the prior_transform.

    Parameters
    ==========
    key: str
        The boundary type to look for; it is also the name under which the
        result is stored in :code:`self.kwargs`. The stored value is None
        when no parameter matches.
    """
    matching = []
    for index, name in enumerate(self.search_parameter_keys):
        if self.priors[name].boundary != key:
            continue
        logger.debug(f"Setting {key} boundary for {name}")
        matching.append(index)
    self.kwargs[key] = matching if matching else None

418 

def nestcheck_data(self, out_file):
    """
    Process a dynesty run with :code:`nestcheck` and pickle the outcome.

    Parameters
    ==========
    out_file:
        The dynesty output handed to
        :code:`nestcheck.data_processing.process_dynesty_run`.
    """
    from nestcheck import data_processing

    processed_run = data_processing.process_dynesty_run(out_file)
    destination = f"{self.outdir}/{self.label}_nestcheck.pickle"
    safe_file_dump(processed_run, destination, "pickle")

425 

426 @property 

427 def nlive(self): 

428 return self.kwargs["nlive"] 

429 

@property
def sampler_init(self):
    """The callable used to construct the sampler, :code:`dynesty.NestedSampler`."""
    # Imported lazily so that dynesty is only required when it is used.
    import dynesty

    return dynesty.NestedSampler

435 

@property
def sampler_class(self):
    """The sampler class whose :code:`run_nested` is used, :code:`dynesty.sampler.Sampler`."""
    # Imported lazily so that dynesty is only required when it is used.
    import dynesty.sampler

    return dynesty.sampler.Sampler

441 

def _set_sampling_method(self):
    """
    Resolve the sampling method and sampler to use from the provided
    :code:`bound` and :code:`sample` arguments.

    This requires registering the :code:`bilby` specific methods in the
    appropriate locations within :code:`dynesty`.

    Additionally, some combinations of bound/sample/proposals are not
    compatible and so we either warn the user or raise an error.

    Raises
    ======
    DynestySetupError
        If :code:`sample="acceptance-walk"` or differential evolution
        proposals are requested without a live-point based bound, or if
        :code:`sample="rwalk"` is used with :code:`maxmcmc < walks` or
        :code:`nact < 1`.
    """
    import dynesty

    _set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept))

    sample = self.kwargs["sample"]
    bound = self.kwargs["bound"]

    if sample not in ["rwalk", "act-walk", "acceptance-walk"] and bound in [
        "live",
        "live-multi",
    ]:
        logger.info(
            "Live-point based bound method requested with dynesty sample "
            f"'{sample}', overwriting to 'multi'"
        )
        self.kwargs["bound"] = "multi"
    elif bound == "live":
        from .dynesty_utils import LivePointSampler

        dynesty.dynamicsampler._SAMPLERS["live"] = LivePointSampler
    elif bound == "live-multi":
        from .dynesty_utils import MultiEllipsoidLivePointSampler

        dynesty.dynamicsampler._SAMPLERS[
            "live-multi"
        ] = MultiEllipsoidLivePointSampler
    elif sample == "acceptance-walk":
        raise DynestySetupError(
            "bound must be set to live or live-multi for sample=acceptance-walk"
        )
    elif self.proposals is None:
        logger.warning(
            "No proposals specified using dynesty sampling, defaulting "
            "to 'volumetric'."
        )
        self.proposals = ["volumetric"]
        _SamplingContainer.proposals = self.proposals
    elif "diff" in self.proposals:
        raise DynestySetupError(
            "bound must be set to live or live-multi to use differential "
            "evolution proposals"
        )

    if sample == "rwalk":
        logger.info(
            f"Using the bilby-implemented {sample} sample method with ACT estimated walks. "
            f"An average of {2 * self.nact} steps will be accepted up to chain length "
            f"{self.maxmcmc}."
        )
        from .dynesty_utils import AcceptanceTrackingRWalk

        if self.kwargs["walks"] > self.maxmcmc:
            raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)")
        if self.nact < 1:
            raise DynestySetupError("Unable to run with nact < 1")
        AcceptanceTrackingRWalk.old_act = None
        dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk()
    elif sample == "acceptance-walk":
        logger.info(
            f"Using the bilby-implemented {sample} sampling with an average of "
            f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}"
        )
        from .dynesty_utils import FixedRWalk

        dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk()
    elif sample == "act-walk":
        logger.info(
            f"Using the bilby-implemented {sample} sampling tracking the "
            f"autocorrelation function and thinning by "
            f"{self.nact} with maximum length {self.nact * self.maxmcmc}"
        )
        from .dynesty_utils import ACTTrackingRWalk

        ACTTrackingRWalk._cache = list()
        dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk()
    elif sample == "rwalk_dynesty":
        # Drop the "_dynesty" suffix to get the method name dynesty knows.
        # ``str.strip`` removes a set of characters rather than a suffix,
        # so the suffix is sliced off explicitly.
        sample = sample[: -len("_dynesty")]
        self.kwargs["sample"] = sample
        logger.info(f"Using the dynesty-implemented {sample} sample method")

532 

@signal_wrapper
def run_sampler(self):
    """
    Run dynesty and return the filled Bilby result.

    The sampling method and the pool are set up first. The sampler is then
    either restored from the resume file or freshly constructed, run with
    or without checkpointing, and its output is saved to
    :code:`{outdir}/{label}_dynesty.pickle` before the result is generated.

    Returns
    =======
    The result object stored in :code:`self.result`.
    """
    import dynesty

    logger.info(f"Using dynesty version {dynesty.__version__}")

    self._set_sampling_method()
    self._setup_pool()

    if self.resume:
        # ``resume`` now records whether a saved state was actually loaded.
        self.resume = self.read_saved_state(continuing=True)

    if self.resume:
        logger.info("Resume file successfully loaded.")
    else:
        # NOTE(review): nesting reconstructed from a flattened listing; the
        # trailing None is only appended to freshly drawn points - confirm.
        if self.kwargs["live_points"] is None:
            initial_points = self.get_initial_points_from_prior(self.nlive)
            self.kwargs["live_points"] = initial_points
            self.kwargs["live_points"] = (*self.kwargs["live_points"], None)
        self.sampler = self.sampler_init(
            loglikelihood=_log_likelihood_wrapper,
            prior_transform=_prior_transform_wrapper,
            ndim=self.ndim,
            **self.sampler_init_kwargs,
        )

    use_progress_bar = self.print_method == "tqdm" and self.kwargs["print_progress"]
    if use_progress_bar:
        from tqdm.auto import tqdm

        self.pbar = tqdm(file=sys.stdout, initial=self.sampler.it)

    self.start_time = datetime.datetime.now()
    if self.check_point:
        out = self._run_external_sampler_with_checkpointing()
    else:
        out = self._run_external_sampler_without_checkpointing()
    # NOTE(review): assumed to run on both paths - confirm against the
    # original indentation.
    self._update_sampling_time()

    self._close_pool()

    # Flushes the output to force a line break
    if self.pbar is not None:
        self.pbar = self.pbar.close()
        print("")

    check_directory_exists_and_if_not_mkdir(self.outdir)

    if self.nestcheck:
        self.nestcheck_data(out)

    dynesty_result = f"{self.outdir}/{self.label}_dynesty.pickle"
    safe_file_dump(out, dynesty_result, "dill")

    self._generate_result(out)
    self.result.sampling_time = self.sampling_time

    return self.result

590 

def _setup_pool(self):
    """
    In addition to the usual steps, we need to set the sampling kwargs on
    every process. To make sure we get every process, run the kwarg setting
    more times than we have processes.
    """
    super(Dynesty, self)._setup_pool()
    if self.pool is None:
        return
    settings = (self.nact, self.maxmcmc, self.proposals, self.naccept)
    # Ten tasks per process.
    tasks = [settings] * self.npool * 10
    self.pool.map(_set_sampling_kwargs, tasks)

605 

def _generate_result(self, out):
    """
    Extract the information we need from the dynesty output. This includes
    the evidence, nested samples, run statistics. In addition, we generate
    the posterior samples from the nested samples.

    The posterior is formed by rejection sampling the nested samples when
    :code:`self.rejection_sample_posterior` is set, otherwise by resampling
    them with :code:`dynesty.utils.resample_equal`.

    Parameters
    ==========
    out: dynesty.result.Result
        The dynesty output.
    """
    import dynesty
    from scipy.special import logsumexp

    from ..utils.random import rng

    logwts = out["logwt"]
    weights = np.exp(logwts - out["logz"][-1])
    nested_samples = DataFrame(out.samples, columns=self.search_parameter_keys)
    nested_samples["weights"] = weights
    nested_samples["log_likelihood"] = out.logl
    self.result.nested_samples = nested_samples
    if self.rejection_sample_posterior:
        keep = weights > rng.uniform(0, max(weights), len(weights))
        self.result.samples = out.samples[keep]
        self.result.log_likelihood_evaluations = out.logl[keep]
        logger.info(
            f"Rejection sampling nested samples to obtain {sum(keep)} posterior samples"
        )
    else:
        self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
        self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
            unsorted_loglikelihoods=out.logl,
            unsorted_samples=out.samples,
            sorted_samples=self.result.samples,
        )
        logger.info("Resampling nested samples to posterior samples in place.")
    self.result.log_evidence = out.logz[-1]
    self.result.log_evidence_err = out.logzerr[-1]
    self.result.information_gain = out.information[-1]
    self.result.num_likelihood_evaluations = getattr(self.sampler, "ncall", 0)

    # Effective sample size: (sum of weights)^2 / (sum of squared weights),
    # evaluated in log space.
    logneff = logsumexp(logwts) * 2 - logsumexp(logwts * 2)
    neffsamples = int(np.exp(logneff))
    self.result.meta_data["run_statistics"] = dict(
        nlikelihood=self.result.num_likelihood_evaluations,
        neffsamples=neffsamples,
        # ``timedelta.seconds`` only holds the sub-day part of the duration,
        # so the total is used to stay correct for runs longer than a day.
        sampling_time_s=int(self.sampling_time.total_seconds()),
        ncores=self.kwargs.get("queue_size", 1),
    )
    self.kwargs["rstate"] = None

657 

658 def _update_sampling_time(self): 

659 end_time = datetime.datetime.now() 

660 self.sampling_time += end_time - self.start_time 

661 self.start_time = end_time 

662 

def _run_external_sampler_without_checkpointing(self):
    """Run the sampler in a single call and return its results."""
    logger.debug("Running sampler without checkpointing")
    sampler = self.sampler
    sampler.run_nested(**self.sampler_function_kwargs)
    return self.sampler.results

667 

668 def finalize_sampler_kwargs(self, sampler_kwargs): 

669 sampler_kwargs["maxcall"] = self.n_check_point 

670 sampler_kwargs["add_live"] = True 

671 

def _run_external_sampler_with_checkpointing(self):
    """
    In order to access the checkpointing, we run the sampler for short
    periods of time (less than the checkpoint time) and if sufficient
    time has passed, write a checkpoint before continuing. To get the most
    informative checkpoint plots, the current live points are added to the
    chain of nested samples within dynesty and have to be removed before
    restarting the sampler.

    Returns
    =======
    The results of the sampler after the final run.
    """

    logger.debug("Running sampler with checkpointing")

    previous_ncall = self.sampler.ncall
    sampler_kwargs = self.sampler_function_kwargs.copy()
    # Every short run stops at the call limit, so this warning is expected.
    warnings.filterwarnings(
        "ignore",
        message="The sampling was stopped short due to maxiter/maxcall limit*",
        category=UserWarning,
        module="dynesty.sampler",
    )
    while True:
        self.finalize_sampler_kwargs(sampler_kwargs)
        if getattr(self.sampler, "added_live", False):
            self.sampler._remove_live_points()
        self.sampler.run_nested(**sampler_kwargs)
        # No new likelihood calls means the sampler has nothing left to do.
        if self.sampler.ncall == previous_ncall:
            break
        previous_ncall = self.sampler.ncall

        # Time since the last checkpoint, or since the start of this run
        # when no checkpoint file exists yet.
        if os.path.isfile(self.resume_file):
            since_checkpoint = time.time() - os.path.getmtime(self.resume_file)
        else:
            since_start = datetime.datetime.now() - self.start_time
            since_checkpoint = since_start.total_seconds()
        if since_checkpoint > self.check_point_delta_t:
            self.write_current_state()
            self.plot_current_state()

    if getattr(self.sampler, "added_live", False):
        self.sampler._remove_live_points()

    self.sampler.run_nested(**sampler_kwargs)
    self.write_current_state()
    self.plot_current_state()
    return self.sampler.results

717 

718 def _remove_checkpoint(self): 

719 """Remove checkpointed state""" 

720 if os.path.isfile(self.resume_file): 

721 os.remove(self.resume_file) 

722 

def read_saved_state(self, continuing=False):
    """
    Read a pickled saved state of the sampler from disk.

    If the live points are present and the run is continuing
    they are removed.
    `nqueue` is set to a negative number to trigger the queue to be
    refilled before the first iteration.
    The start time and the accumulated sampling time stored with the
    sampler are restored onto self, and the pool is re-attached.

    Parameters
    ==========
    continuing: bool
        Whether the run is continuing or terminating. If True, live points
        that were added to the saved sampler are removed.

    Returns
    =======
    bool
        True if the sampler was restored, False if the resume file does not
        exist or cannot be used.
    """
    import dill
    from dynesty import __version__ as dynesty_version

    from ... import __version__ as bilby_version

    if not os.path.isfile(self.resume_file):
        logger.info(f"Resume file {self.resume_file} does not exist.")
        return False

    current_versions = dict(bilby=bilby_version, dynesty=dynesty_version)
    logger.info(f"Reading resume file {self.resume_file}")
    # NOTE: loading a pickle can execute arbitrary code; only resume from
    # files that are trusted.
    with open(self.resume_file, "rb") as file:
        try:
            sampler = dill.load(file)
        except EOFError:
            sampler = None

    if not hasattr(sampler, "versions"):
        logger.warning(
            f"The resume file {self.resume_file} is corrupted or "
            "the version of bilby has changed between runs. This "
            "resume file will be ignored."
        )
        return False

    version_warning = (
        "The {code} version has changed between runs. "
        "This may cause unpredictable behaviour and/or failure. "
        "Old version = {old}, new version = {new}."
    )
    for code, current in current_versions.items():
        if not current == sampler.versions.get(code, None):
            logger.warning(
                version_warning.format(
                    code=code,
                    old=sampler.versions.get(code, "None"),
                    new=current,
                )
            )
    del sampler.versions

    self.sampler = sampler
    if getattr(self.sampler, "added_live", False) and continuing:
        self.sampler._remove_live_points()
    self.sampler.nqueue = -1
    self.start_time = self.sampler.kwargs.pop("start_time")
    self.sampling_time = self.sampler.kwargs.pop("sampling_time")
    self.sampler.queue_size = self.kwargs["queue_size"]
    self.sampler.pool = self.pool
    self.sampler.M = map if self.pool is None else self.pool.map
    return True

792 

def write_current_state_and_exit(self, signum=None, frame=None):
    """
    Close the progress bar, if any, then defer to the parent implementation.

    Parameters
    ==========
    signum, frame:
        Passed through unchanged to the parent implementation.
    """
    progress_bar = self.pbar
    if progress_bar is not None:
        self.pbar = progress_bar.close()
    super(Dynesty, self).write_current_state_and_exit(signum=signum, frame=frame)

797 

def write_current_state(self):
    """
    Write the current state of the sampler to disk.

    The sampler is pickle dumped using `dill`.
    The sampling time is also stored to get the full CPU time for the run.

    The pool is detached from the sampler while it is dumped and attached
    again afterwards.

    The check of whether the sampler is picklable is to catch an error
    when using pytest. Hopefully, this message won't be triggered during
    normal running.
    """

    import dill
    from dynesty import __version__ as dynesty_version

    from ... import __version__ as bilby_version

    sampler = getattr(self, "sampler", None)
    if sampler is None:
        # Sampler not initialized, not able to write current state
        return

    check_directory_exists_and_if_not_mkdir(self.outdir)
    if hasattr(self, "start_time"):
        self._update_sampling_time()
        sampler.kwargs["sampling_time"] = self.sampling_time
        sampler.kwargs["start_time"] = self.start_time
    sampler.versions = dict(bilby=bilby_version, dynesty=dynesty_version)

    # Detach the pool for the dump; it is restored below.
    sampler.pool = None
    sampler.M = map
    if dill.pickles(sampler):
        safe_file_dump(sampler, self.resume_file, dill)
        logger.info(f"Written checkpoint file {self.resume_file}")
    else:
        logger.warning(
            "Cannot write pickle resume file! "
            "Job will not resume if interrupted."
        )
    sampler.pool = self.pool
    if sampler.pool is not None:
        sampler.M = sampler.pool.map

838 

def dump_samples_to_dat(self):
    """
    Save the current posterior samples to a space-separated plain-text
    file. These are unbiased posterior samples, however, there will not
    be many of them until the analysis is nearly over.

    Nothing is written when fewer than 100 samples are available.
    """
    sampler = self.sampler
    weights = np.exp(sampler.saved_logwt - sampler.saved_logz[-1])
    samples = rejection_sample(np.array(sampler.saved_v), weights)

    # If we don't have enough samples, don't dump them
    nsamples = len(samples)
    if nsamples < 100:
        return

    filename = f"{self.outdir}/{self.label}_samples.dat"
    logger.info(f"Writing {nsamples} current samples to {filename}")

    frame = DataFrame(samples, columns=self.search_parameter_keys)
    frame.to_csv(filename, index=False, header=True, sep=" ")

861 

def plot_current_state(self):
    """
    Make diagnostic plots of the history and current state of the sampler.

    These plots are a mixture of :code:`dynesty` implemented run and trace
    plots and our custom stats plot. We also make a copy of the trace plot
    using the unit hypercube samples to reflect the internal state of the
    sampler.

    Nothing is done unless :code:`self.check_point_plot` is truthy. Each
    figure is written to
    :code:`{self.outdir}/{self.label}_checkpoint_<name>.png` where
    :code:`<name>` is one of :code:`trace`, :code:`trace_unit`, :code:`run`
    and :code:`stats`.

    Any errors during plotting should be handled so that sampling can
    continue.
    """
    if not self.check_point_plot:
        return

    import dynesty.plotting as dyplot
    import matplotlib.pyplot as plt

    labels = [label.replace("_", " ") for label in self.search_parameter_keys]

    # ``LinAlgError`` is taken from the public ``np.linalg`` namespace:
    # ``np.linalg.linalg`` was made private in NumPy 2.0 and looking it up
    # inside an ``except`` clause would itself warn (or fail) while an
    # exception is being handled.
    expected_errors = (
        RuntimeError,
        np.linalg.LinAlgError,
        ValueError,
        OverflowError,
    )

    def save_plot(name, description, make_figure):
        """Build one figure and save it, logging rather than raising."""
        try:
            filename = f"{self.outdir}/{self.label}_checkpoint_{name}.png"
            fig = make_figure()
            fig.tight_layout()
            # Save through the figure handle so the output never depends
            # on which figure pyplot currently considers active.
            fig.savefig(filename)
        except expected_errors as e:
            logger.warning(e)
            logger.warning(
                f"Failed to create dynesty {description} plot at checkpoint"
            )
        except DynestySetupError:
            # Only the stats plot raises this (sampler without ``nlive``).
            logger.debug("Cannot create Dynesty stats plot with dynamic sampler.")
        except Exception as e:
            logger.warning(
                f"Unexpected error {e} in dynesty plotting. "
                "Please report at git.ligo.org/lscsoft/bilby/-/issues"
            )
        finally:
            # Always release figures, including partially built ones.
            plt.close("all")

    def trace_figure():
        return dyplot.traceplot(self.sampler.results, labels=labels)[0]

    def unit_trace_figure():
        # Imported lazily so that an import failure is handled like any
        # other plotting error.
        from copy import deepcopy

        from dynesty.utils import results_substitute

        # Work on a copy so the sampler's own results are left untouched
        # when the samples are swapped for their unit hypercube values.
        temp = deepcopy(self.sampler.results)
        temp = results_substitute(temp, dict(samples=temp["samples_u"]))
        return dyplot.traceplot(temp, labels=labels)[0]

    def run_figure():
        fig, _ = dyplot.runplot(
            self.sampler.results, logplot=False, use_math_text=False
        )
        return fig

    def stats_figure():
        fig, _ = dynesty_stats_plot(self.sampler)
        return fig

    save_plot("trace", "state", trace_figure)
    save_plot("trace_unit", "unit state", unit_trace_figure)
    save_plot("run", "run", run_figure)
    save_plot("stats", "stats", stats_figure)

964 

def _run_test(self):
    """Run the sampler very briefly as a sanity test that it works.

    The sampler is built with :code:`self.sampler_init` and run with
    :code:`maxiter=2`. The result is then filled with placeholder values:
    100 draws from :code:`self.priors` as both the samples and the nested
    samples, unit log-likelihood evaluations, a log evidence of 1 and a
    log evidence error of 0.1.

    Returns
    =======
    The :code:`self.result` object holding the placeholder values.

    """
    self._set_sampling_method()
    self._setup_pool()
    try:
        self.sampler = self.sampler_init(
            loglikelihood=_log_likelihood_wrapper,
            prior_transform=_prior_transform_wrapper,
            ndim=self.ndim,
            **self.sampler_init_kwargs,
        )
        # Copy so the two-iteration cap does not leak into the stored kwargs.
        sampler_kwargs = self.sampler_function_kwargs.copy()
        sampler_kwargs["maxiter"] = 2

        if self.print_method == "tqdm" and self.kwargs["print_progress"]:
            from tqdm.auto import tqdm

            self.pbar = tqdm(file=sys.stdout, initial=self.sampler.it)
        self.sampler.run_nested(**sampler_kwargs)
    finally:
        # Release the pool and the progress bar even when building or
        # running the sampler fails, so a failed test does not leave them
        # open.
        self._close_pool()
        if self.pbar is not None:
            self.pbar.close()
            self.pbar = None
            print("")

    n_samples = 100
    prior_draws = DataFrame(self.priors.sample(n_samples))
    self.result.samples = prior_draws[self.search_parameter_keys].values
    self.result.nested_samples = self.result.samples
    self.result.log_likelihood_evaluations = np.ones(n_samples)
    self.result.log_evidence = 1
    self.result.log_evidence_err = 0.1

    return self.result

1000 

def prior_transform(self, theta):
    """Rescale unit-interval samples using the priors.

    The values in :code:`theta` are handed, together with the search
    parameter keys, to :code:`self.priors.rescale` and the outcome is
    returned unchanged.

    Parameters
    ==========
    theta: list
        List of sampled values on a unit interval

    Returns
    =======
    list: Properly rescaled sampled values

    """
    keys = self._search_parameter_keys
    rescaled = self.priors.rescale(keys, theta)
    return rescaled

1016 

1017 

@latex_plot_format
def dynesty_stats_plot(sampler):
    """
    Plot diagnostic statistics from a dynesty run

    The plotted quantities per iteration are:

    - nc: the number of likelihood calls
    - scale: the number of accepted MCMC steps if using :code:`bound="live"`
      or :code:`bound="live-multi"`, otherwise, the scale applied to the MCMC
      steps
    - lifetime: the number of iterations a point stays in the live set

    There is also a histogram of the lifetime compared with the theoretical
    distribution, a geometric distribution with :code:`p = 1 / nlive`. To
    avoid edge effects, the histogram excludes an initial burn-in and the
    final :code:`nlive` iterations. The burn-in is the lifetime at which the
    survival function of that distribution equals :code:`1 / (2 * nlive)`.
    The histogram is only drawn when the run is longer than the burn-in plus
    :code:`nlive` iterations.

    Parameters
    ----------
    sampler: dynesty.sampler.Sampler
        The sampler object containing the run history.

    Returns
    -------
    fig: matplotlib.figure.Figure
        Figure handle for the new plot
    axs: array of matplotlib.axes.Axes
        Axes handles for the new plot

    Raises
    ------
    DynestySetupError
        If the sampler has no :code:`nlive` attribute.

    """
    import matplotlib.pyplot as plt
    from scipy.stats import geom, ks_1samp

    # Validate before creating a figure so that an unsupported sampler does
    # not leave an orphaned figure open.
    if not hasattr(sampler, "nlive"):
        raise DynestySetupError("Cannot make stats plot for dynamic sampler.")
    nlive = sampler.nlive

    fig, axs = plt.subplots(nrows=4, figsize=(8, 8))
    data = sampler.saved_run.D
    for ax, name in zip(axs, ["nc", "scale"]):
        ax.plot(data[name], color="blue")
        ax.set_ylabel(name.title())

    lifetime_ax, hist_ax = axs[-2], axs[-1]
    n_iter = len(data["it"])
    iterations = np.arange(n_iter)
    lifetimes = iterations - data["it"]
    lifetime_ax.set_ylabel("Lifetime")

    lifetime_dist = geom(p=1 / nlive)
    burn = int(lifetime_dist.isf(1 / 2 / nlive))
    # Index of the first of the final ``nlive`` iterations. Clamped at zero
    # so runs shorter than ``nlive`` still give x and y of equal length.
    tail_start = max(n_iter - nlive, 0)

    if n_iter > burn + nlive:
        lifetime_ax.plot(iterations[:burn], lifetimes[:burn], color="grey")
        lifetime_ax.plot(
            iterations[burn:tail_start],
            lifetimes[burn:tail_start],
            color="blue",
        )
        lifetime_ax.plot(
            iterations[tail_start:],
            lifetimes[tail_start:],
            color="red",
        )
        # Only the section between burn-in and the final nlive iterations is
        # compared with the theoretical distribution.
        lifetimes = lifetimes[burn:tail_start]
        ks_result = ks_1samp(lifetimes, lifetime_dist.cdf)
        hist_ax.hist(
            lifetimes,
            bins=np.linspace(0, 6 * nlive, 60),
            histtype="step",
            density=True,
            color="blue",
            label=f"p value = {ks_result.pvalue:.3f}",
        )
        support = np.arange(1, 6 * nlive)
        hist_ax.plot(support, lifetime_dist.pmf(support), color="red")
        hist_ax.set_xlim(0, 6 * nlive)
        hist_ax.legend()
        hist_ax.set_yscale("log")
    else:
        lifetime_ax.plot(
            iterations[:tail_start], lifetimes[:tail_start], color="grey"
        )
        lifetime_ax.plot(
            iterations[tail_start:],
            lifetimes[tail_start:],
            color="red",
        )
    lifetime_ax.set_yscale("log")
    lifetime_ax.set_xlabel("Iteration")
    hist_ax.set_xlabel("Lifetime")
    return fig, axs

1104 

1105 

class DynestySetupError(Exception):
    """Raised when the dynesty sampler setup does not support an operation.

    For example, the stats plot raises it for a sampler that has no
    ``nlive`` attribute.
    """