Coverage for bilby/core/sampler/ptemcee.py: 34%

520 statements  

coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import copy 

2import datetime 

3import logging 

4import os 

5import time 

6from collections import namedtuple 

7 

8import numpy as np 

9import pandas as pd 

10 

11from ..utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump 

12from .base_sampler import ( 

13 MCMCSampler, 

14 SamplerError, 

15 _sampling_convenience_dump, 

16 signal_wrapper, 

17) 

18 

19ConvergenceInputs = namedtuple( 

20 "ConvergenceInputs", 

21 [ 

22 "autocorr_c", 

23 "autocorr_tol", 

24 "autocorr_tau", 

25 "gradient_tau", 

26 "gradient_mean_log_posterior", 

27 "Q_tol", 

28 "safety", 

29 "burn_in_nact", 

30 "burn_in_fixed_discard", 

31 "mean_logl_frac", 

32 "thin_by_nact", 

33 "nsamples", 

34 "ignore_keys_for_tau", 

35 "min_tau", 

36 "niterations_per_check", 

37 ], 

38) 

39 

40 

41class Ptemcee(MCMCSampler): 

42 """bilby wrapper ptemcee (https://github.com/willvousden/ptemcee) 

43 

44 All positional and keyword arguments (i.e., the args and kwargs) passed to 

45    `run_sampler` will be propagated to `ptemcee.Sampler`; see the 

46 documentation for that class for further help. Under Other Parameters, we 

47 list commonly used kwargs and the bilby defaults. 

48 

49 Parameters 

50 ---------- 

51 nsamples: int, (5000) 

52 The requested number of samples. Note, in cases where the 

53 autocorrelation parameter is difficult to measure, it is possible to 

54 end up with more than nsamples. 

55    burn_in_nact, thin_by_nact: int and float, (50, 0.5) 

56 The number of burn-in autocorrelation times to discard and the thin-by 

57 factor. Increasing burn_in_nact increases the time required for burn-in. 

58 Increasing thin_by_nact increases the time required to obtain nsamples. 

59 burn_in_fixed_discard: int (0) 

60 A fixed number of samples to discard for burn-in 

61    mean_logl_frac: float, (0.01) 

62    The maximum fractional change in the mean log-likelihood to accept 

63 autocorr_tol: int, (50) 

64 The minimum number of autocorrelation times needed to trust the 

65 estimate of the autocorrelation time. 

66 autocorr_c: int, (5) 

67 The step size for the window search used by emcee.autocorr.integrated_time 

68 safety: int, (1) 

69 A multiplicative factor for the estimated autocorrelation. Useful for 

70 cases where non-convergence can be observed by eye but the automated 

71 tools are failing. 

72 autocorr_tau: int, (1) 

73 The number of autocorrelation times to use in assessing if the 

74 autocorrelation time is stable. 

75 gradient_tau: float, (0.1) 

76 The maximum (smoothed) local gradient of the ACT estimate to allow. 

77 This ensures the ACT estimate is stable before finishing sampling. 

78 gradient_mean_log_posterior: float, (0.1) 

79    The maximum (smoothed) local gradient of the mean log-posterior to allow. 

80    This ensures the mean log-posterior is stable before finishing sampling. 

81    Q_tol: float (1.02) 

82 The maximum between-chain to within-chain tolerance allowed (akin to 

83 the Gelman-Rubin statistic). 

84 min_tau: int, (1) 

85 A minimum tau (autocorrelation time) to accept. 

86 check_point_delta_t: float, (600) 

87 The period with which to checkpoint (in seconds). 

88 threads: int, (1) 

89    If threads > 1, a MultiPool object is set up and used. 

90 exit_code: int, (77) 

91 The code on which the sampler exits. 

92 store_walkers: bool (False) 

93 If true, store the unthinned, unburnt chains in the result. Note, this 

94 is not recommended for cases where tau is large. 

95 ignore_keys_for_tau: str 

96 A pattern used to ignore keys in estimating the autocorrelation time. 

97 pos0: str, list, np.ndarray, dict 

98 If a string, one of "prior" or "minimize". For "prior", the initial 

99    positions of the walkers are drawn from the prior. If "minimize", 

100 a scipy.optimize step is applied to all parameters a number of times. 

101 The walkers are then initialized from the range of values obtained. 

102 If a list, for the keys in the list the optimization step is applied, 

103 otherwise the initial points are drawn from the prior. 

104 If a :code:`numpy` array the shape should be 

105 :code:`(ntemps, nwalkers, ndim)`. 

106 If a :code:`dict`, this should be a dictionary with keys matching the 

107 :code:`search_parameter_keys`. Each entry should be an array with 

108 shape :code:`(ntemps, nwalkers)`. 

109 

110 niterations_per_check: int (5) 

111 The number of iteration steps to take before checking ACT. This 

112 effectively pre-thins the chains. Larger values reduce the per-eval 

113 timing due to improved efficiency. But, if it is made too large the 

114 pre-thinning may be overly aggressive effectively wasting compute-time. 

115 If you see tau=1, then niterations_per_check is likely too large. 

116 

117 

118 Other Parameters 

119 ---------------- 

120    nwalkers: int, (100) 

121 The number of walkers 

122 nsteps: int, (100) 

123 The number of steps to take 

124 ntemps: int (10) 

125 The number of temperatures used by ptemcee 

126 Tmax: float 

127 The maximum temperature 

128 
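    Examples
    --------
    A minimal sketch of driving this wrapper through :code:`bilby.run_sampler`
    (the :code:`likelihood` and :code:`priors` objects are assumed to be
    defined elsewhere; the values shown are illustrative only)::

        import bilby

        result = bilby.run_sampler(
            likelihood=likelihood,
            priors=priors,
            sampler="ptemcee",
            nsamples=5000,
            ntemps=10,
            nwalkers=100,
            outdir="outdir",
            label="label",
        )
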

129 """ 

130 

131 sampler_name = "ptemcee" 

132 # Arguments used by ptemcee 

133 default_kwargs = dict( 

134 ntemps=10, 

135 nwalkers=100, 

136 Tmax=None, 

137 betas=None, 

138 a=2.0, 

139 adaptation_lag=10000, 

140 adaptation_time=100, 

141 random=None, 

142 adapt=False, 

143 swap_ratios=False, 

144 ) 

145 

146 def __init__( 

147 self, 

148 likelihood, 

149 priors, 

150 outdir="outdir", 

151 label="label", 

152 use_ratio=False, 

153 check_point_plot=True, 

154 skip_import_verification=False, 

155 resume=True, 

156 nsamples=5000, 

157 burn_in_nact=50, 

158 burn_in_fixed_discard=0, 

159 mean_logl_frac=0.01, 

160 thin_by_nact=0.5, 

161 autocorr_tol=50, 

162 autocorr_c=5, 

163 safety=1, 

164 autocorr_tau=1, 

165 gradient_tau=0.1, 

166 gradient_mean_log_posterior=0.1, 

167 Q_tol=1.02, 

168 min_tau=1, 

169 check_point_delta_t=600, 

170 threads=1, 

171 exit_code=77, 

172 plot=False, 

173 store_walkers=False, 

174 ignore_keys_for_tau=None, 

175 pos0="prior", 

176 niterations_per_check=5, 

177 log10beta_min=None, 

178 verbose=True, 

179 **kwargs, 

180 ): 

181 super(Ptemcee, self).__init__( 

182 likelihood=likelihood, 

183 priors=priors, 

184 outdir=outdir, 

185 label=label, 

186 use_ratio=use_ratio, 

187 plot=plot, 

188 skip_import_verification=skip_import_verification, 

189 exit_code=exit_code, 

190 **kwargs, 

191 ) 

192 

193 self.nwalkers = self.sampler_init_kwargs["nwalkers"] 

194 self.ntemps = self.sampler_init_kwargs["ntemps"] 

195 self.max_steps = 500 

196 

197 # Checkpointing inputs 

198 self.resume = resume 

199 self.check_point_delta_t = check_point_delta_t 

200 self.check_point_plot = check_point_plot 

201 self.resume_file = f"{self.outdir}/{self.label}_checkpoint_resume.pickle" 

202 

203 # Store convergence checking inputs in a named tuple 

204 convergence_inputs_dict = dict( 

205 autocorr_c=autocorr_c, 

206 autocorr_tol=autocorr_tol, 

207 autocorr_tau=autocorr_tau, 

208 safety=safety, 

209 burn_in_nact=burn_in_nact, 

210 burn_in_fixed_discard=burn_in_fixed_discard, 

211 mean_logl_frac=mean_logl_frac, 

212 thin_by_nact=thin_by_nact, 

213 gradient_tau=gradient_tau, 

214 gradient_mean_log_posterior=gradient_mean_log_posterior, 

215 Q_tol=Q_tol, 

216 nsamples=nsamples, 

217 ignore_keys_for_tau=ignore_keys_for_tau, 

218 min_tau=min_tau, 

219 niterations_per_check=niterations_per_check, 

220 ) 

221 self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict) 

222 logger.info(f"Using convergence inputs: {self.convergence_inputs}") 

223 

224 # Check if threads was given as an equivalent arg 

225 if threads == 1: 

226 for equiv in self.npool_equiv_kwargs: 

227 if equiv in kwargs: 

228 threads = kwargs.pop(equiv) 

229 

230 # Store threads 

231 self.threads = threads 

232 

233 # Misc inputs 

234 self.store_walkers = store_walkers 

235 self.pos0 = pos0 

236 

237 self._periodic = [ 

238 self.priors[key].boundary == "periodic" 

239 for key in self.search_parameter_keys 

240 ] 

241 self.priors.sample() 

242 self._minima = np.array( 

243 [self.priors[key].minimum for key in self.search_parameter_keys] 

244 ) 

245 self._range = ( 

246 np.array([self.priors[key].maximum for key in self.search_parameter_keys]) 

247 - self._minima 

248 ) 

249 

250 self.log10beta_min = log10beta_min 

251 if self.log10beta_min is not None: 

252 betas = np.logspace(0, self.log10beta_min, self.ntemps) 

253 logger.warning(f"Using betas {betas}") 

254 self.kwargs["betas"] = betas 

255 self.verbose = verbose 

256 

257 self.iteration = 0 

258 self.chain_array = self.get_zero_chain_array() 

259 self.log_likelihood_array = self.get_zero_array() 

260 self.log_posterior_array = self.get_zero_array() 

261 self.beta_list = list() 

262 self.tau_list = list() 

263 self.tau_list_n = list() 

264 self.Q_list = list() 

265 self.time_per_check = list() 

266 

267 self.nburn = np.nan 

268 self.thin = np.nan 

269 self.tau_int = np.nan 

270 self.nsamples_effective = 0 

271 self.discard = 0 

272 

273 @property 

274 def sampler_function_kwargs(self): 

275 """Kwargs passed to samper.sampler()""" 

276 keys = ["adapt", "swap_ratios"] 

277 return {key: self.kwargs[key] for key in keys} 

278 

279 @property 

280 def sampler_init_kwargs(self): 

281 """Kwargs passed to initialize ptemcee.Sampler()""" 

282 return { 

283 key: value 

284 for key, value in self.kwargs.items() 

285 if key not in self.sampler_function_kwargs 

286 } 

287 

288 def _translate_kwargs(self, kwargs): 

289 """Translate kwargs""" 

290 kwargs = super()._translate_kwargs(kwargs) 

291 if "nwalkers" not in kwargs: 

292 for equiv in self.nwalkers_equiv_kwargs: 

293 if equiv in kwargs: 

294 kwargs["nwalkers"] = kwargs.pop(equiv) 

295 

296 def get_pos0_from_prior(self): 

297 """Draw the initial positions from the prior 

298 

299 Returns 

300 ------- 

301 pos0: list 

302    The initial positions of the walkers, with shape (ntemps, nwalkers, ndim) 

303 

304 """ 

305 logger.info("Generating pos0 samples") 

306 return np.array( 

307 [ 

308 [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] 

309 for _ in range(self.kwargs["ntemps"]) 

310 ] 

311 ) 

312 

313 def get_pos0_from_minimize(self, minimize_list=None): 

314 """Draw the initial positions using an initial minimization step 

315 

316 See pos0 in the class initialization for details. 

317 

318 Returns 

319 ------- 

320 pos0: list 

321    The initial positions of the walkers, with shape (ntemps, nwalkers, ndim) 

322 

323 """ 

324 

325 from scipy.optimize import minimize 

326 

327 from ..utils.random import rng 

328 

329 # Set up the minimize list: keys not in this list will have initial 

330 # positions drawn from the prior 

331 if minimize_list is None: 

332 minimize_list = self.search_parameter_keys 

333 pos0 = np.zeros((self.kwargs["ntemps"], self.kwargs["nwalkers"], self.ndim)) 

334 else: 

335 pos0 = np.array(self.get_pos0_from_prior()) 

336 

337 logger.info(f"Attempting to set pos0 for {minimize_list} from minimize") 

338 

339 likelihood_copy = copy.copy(self.likelihood) 

340 

341 def neg_log_like(params): 

342 """Internal function to minimize""" 

343 likelihood_copy.parameters.update( 

344 {key: val for key, val in zip(minimize_list, params)} 

345 ) 

346 try: 

347 return -likelihood_copy.log_likelihood() 

348 except RuntimeError: 

349 return +np.inf 

350 

351 # Bounds used in the minimization 

352 bounds = [ 

353 (self.priors[key].minimum, self.priors[key].maximum) 

354 for key in minimize_list 

355 ] 

356 

357 # Run the minimization step several times to get a range of values 

358 trials = 0 

359 success = [] 

360 while True: 

361 draw = self.priors.sample() 

362 likelihood_copy.parameters.update(draw) 

363 x0 = [draw[key] for key in minimize_list] 

364 res = minimize( 

365 neg_log_like, x0, bounds=bounds, method="L-BFGS-B", tol=1e-15 

366 ) 

367 if res.success: 

368 success.append(res.x) 

369 if trials > 100: 

370 raise SamplerError("Unable to set pos0 from minimize") 

371 if len(success) >= 10: 

372 break 

373 

374 # Initialize positions from the range of values 

375 success = np.array(success) 

376 for i, key in enumerate(minimize_list): 

377 pos0_min = np.min(success[:, i]) 

378 pos0_max = np.max(success[:, i]) 

379 logger.info(f"Initialize {key} walkers from {pos0_min}->{pos0_max}") 

380 j = self.search_parameter_keys.index(key) 

381 pos0[:, :, j] = rng.uniform( 

382 pos0_min, 

383 pos0_max, 

384 size=(self.kwargs["ntemps"], self.kwargs["nwalkers"]), 

385 ) 

386 return pos0 

387 

388 def get_pos0_from_array(self): 

389 if self.pos0.shape != (self.ntemps, self.nwalkers, self.ndim): 

390 raise ValueError( 

391 "Shape of starting array should be (ntemps, nwalkers, ndim). " 

392 f"In this case that is ({self.ntemps}, {self.nwalkers}, " 

393 f"{self.ndim}), got {self.pos0.shape}" 

394 ) 

395 else: 

396 return self.pos0 

397 

398 def get_pos0_from_dict(self): 

399 """ 

400 Initialize the starting points from a passed dictionary. 

401 

402 The :code:`pos0` passed to the :code:`Sampler` should be a dictionary 

403 with keys matching the :code:`search_parameter_keys`. 

404 Each entry should have shape :code:`(ntemps, nwalkers)`. 

405 """ 

406 pos0 = np.array([self.pos0[key] for key in self.search_parameter_keys]) 

407 self.pos0 = np.moveaxis(pos0, 0, -1) 

408 return self.get_pos0_from_array() 
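    # Illustrative pos0 dict (hypothetical parameter names): each entry has
    # shape (ntemps, nwalkers); the method above stacks the entries along the
    # last axis to give the required (ntemps, nwalkers, ndim) array, e.g.
    #
    #     pos0 = {
    #         "mass_1": np.full((10, 100), 30.0),
    #         "mass_2": np.full((10, 100), 25.0),
    #     }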

409 

410 def setup_sampler(self): 

411 """Either initialize the sampler or read in the resume file""" 

412 import ptemcee 

413 

414 if ptemcee.__version__ == "1.0.0": 

415 # This is a very ugly hack to support numpy>=1.24 

416 ptemcee.sampler.np.float = float 

417 

418 if ( 

419 os.path.isfile(self.resume_file) 

420 and os.path.getsize(self.resume_file) 

421 and self.resume is True 

422 ): 

423 import dill 

424 

425 logger.info(f"Resume data {self.resume_file} found") 

426 with open(self.resume_file, "rb") as file: 

427 data = dill.load(file) 

428 

429 # Extract the check-point data 

430 self.sampler = data["sampler"] 

431 self.iteration = data["iteration"] 

432 self.chain_array = data["chain_array"] 

433 self.log_likelihood_array = data["log_likelihood_array"] 

434 self.log_posterior_array = data["log_posterior_array"] 

435 self.pos0 = data["pos0"] 

436 self.beta_list = data["beta_list"] 

437 self.sampler._betas = np.array(self.beta_list[-1]) 

438 self.tau_list = data["tau_list"] 

439 self.tau_list_n = data["tau_list_n"] 

440 self.Q_list = data["Q_list"] 

441 self.time_per_check = data["time_per_check"] 

442 

443 # Initialize the pool 

444 self.sampler.pool = self.pool 

445 self.sampler.threads = self.threads 

446 

447 logger.info(f"Resuming from previous run with time={self.iteration}") 

448 

449 else: 

450 # Initialize the PTSampler 

451 if self.threads == 1: 

452 self.sampler = ptemcee.Sampler( 

453 dim=self.ndim, 

454 logl=self.log_likelihood, 

455 logp=self.log_prior, 

456 **self.sampler_init_kwargs, 

457 ) 

458 else: 

459 self.sampler = ptemcee.Sampler( 

460 dim=self.ndim, 

461 logl=do_nothing_function, 

462 logp=do_nothing_function, 

463 threads=self.threads, 

464 **self.sampler_init_kwargs, 

465 ) 

466 

467 self.sampler._likeprior = LikePriorEvaluator() 

468 

469 # Initialize storing results 

470 self.iteration = 0 

471 self.chain_array = self.get_zero_chain_array() 

472 self.log_likelihood_array = self.get_zero_array() 

473 self.log_posterior_array = self.get_zero_array() 

474 self.beta_list = list() 

475 self.tau_list = list() 

476 self.tau_list_n = list() 

477 self.Q_list = list() 

478 self.time_per_check = list() 

479 self.pos0 = self.get_pos0() 

480 

481 return self.sampler 

482 

483 def get_zero_chain_array(self): 

484 return np.zeros((self.nwalkers, self.max_steps, self.ndim)) 

485 

486 def get_zero_array(self): 

487 return np.zeros((self.ntemps, self.nwalkers, self.max_steps)) 

488 

489 def get_pos0(self): 

490 """Master logic for setting pos0""" 

491 if isinstance(self.pos0, str) and self.pos0.lower() == "prior": 

492 return self.get_pos0_from_prior() 

493 elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize": 

494 return self.get_pos0_from_minimize() 

495 elif isinstance(self.pos0, list): 

496 return self.get_pos0_from_minimize(minimize_list=self.pos0) 

497 elif isinstance(self.pos0, np.ndarray): 

498 return self.get_pos0_from_array() 

499 elif isinstance(self.pos0, dict): 

500 return self.get_pos0_from_dict() 

501 else: 

502 raise SamplerError(f"pos0={self.pos0} not implemented") 

503 

504 def _close_pool(self): 

505 if getattr(self.sampler, "pool", None) is not None: 

506 self.sampler.pool = None 

507 if "pool" in self.result.sampler_kwargs: 

508 del self.result.sampler_kwargs["pool"] 

509 super(Ptemcee, self)._close_pool() 

510 

511 @signal_wrapper 

512 def run_sampler(self): 

513 self._setup_pool() 

514 sampler = self.setup_sampler() 

515 

516 t0 = datetime.datetime.now() 

517 logger.info("Starting to sample") 

518 

519 while True: 

520 for pos0, log_posterior, log_likelihood in sampler.sample( 

521 self.pos0, 

522 storechain=False, 

523 iterations=self.convergence_inputs.niterations_per_check, 

524 **self.sampler_function_kwargs, 

525 ): 
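                # Wrap periodic parameters back into their prior range
                # (minimum + mod(value - minimum, range)) so walkers that step
                # across a periodic boundary re-enter on the other side.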

526 pos0[:, :, self._periodic] = ( 

527 np.mod( 

528 pos0[:, :, self._periodic] - self._minima[self._periodic], 

529 self._range[self._periodic], 

530 ) 

531 + self._minima[self._periodic] 

532 ) 

533 

534 if self.iteration == self.chain_array.shape[1]: 

535 self.chain_array = np.concatenate( 

536 (self.chain_array, self.get_zero_chain_array()), axis=1 

537 ) 

538 self.log_likelihood_array = np.concatenate( 

539 (self.log_likelihood_array, self.get_zero_array()), axis=2 

540 ) 

541 self.log_posterior_array = np.concatenate( 

542 (self.log_posterior_array, self.get_zero_array()), axis=2 

543 ) 

544 

545 self.pos0 = pos0 

546 

547 self.chain_array[:, self.iteration, :] = pos0[0, :, :] 

548 self.log_likelihood_array[:, :, self.iteration] = log_likelihood 

549 self.log_posterior_array[:, :, self.iteration] = log_posterior 

550 self.mean_log_posterior = np.mean( 

551 self.log_posterior_array[:, :, : self.iteration], axis=1 

552 ) 

553 

554    # log_posterior_array is shaped (ntemps, nwalkers, iterations), 

555    # so mean_log_posterior is shaped (ntemps, iterations) 

556 

557 # Calculate time per iteration 

558 self.time_per_check.append((datetime.datetime.now() - t0).total_seconds()) 

559 t0 = datetime.datetime.now() 

560 

561 self.iteration += 1 

562 

563 # Calculate minimum iteration step to discard 

564 minimum_iteration = get_minimum_stable_itertion( 

565 self.mean_log_posterior, frac=self.convergence_inputs.mean_logl_frac 

566 ) 

567 logger.debug(f"Minimum iteration = {minimum_iteration}") 

568 

569 # Calculate the maximum discard number 

570 discard_max = np.max( 

571 [self.convergence_inputs.burn_in_fixed_discard, minimum_iteration] 

572 ) 

573 

574 if self.iteration > discard_max + self.nwalkers: 

575 # If we have taken more than nwalkers steps after the discard 

576 # then set the discard 

577 self.discard = discard_max 

578 else: 

579    # Otherwise, discard everything so far (to avoid initialisation bias) 

580 logger.debug("Too few steps to calculate convergence") 

581 self.discard = self.iteration 

582 

583 ( 

584 stop, 

585 self.nburn, 

586 self.thin, 

587 self.tau_int, 

588 self.nsamples_effective, 

589 ) = check_iteration( 

590 self.iteration, 

591 self.chain_array[:, self.discard : self.iteration, :], 

592 sampler, 

593 self.convergence_inputs, 

594 self.search_parameter_keys, 

595 self.time_per_check, 

596 self.beta_list, 

597 self.tau_list, 

598 self.tau_list_n, 

599 self.Q_list, 

600 self.mean_log_posterior, 

601 verbose=self.verbose, 

602 ) 

603 

604 if stop: 

605 logger.info("Finished sampling") 

606 break 

607 

608 # If a checkpoint is due, checkpoint 

609 if os.path.isfile(self.resume_file): 

610 last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file) 

611 else: 

612 last_checkpoint_s = np.sum(self.time_per_check) 

613 

614 if last_checkpoint_s > self.check_point_delta_t: 

615 self.write_current_state(plot=self.check_point_plot) 

616 

617 # Run a final checkpoint to update the plots and samples 

618 self.write_current_state(plot=self.check_point_plot) 

619 

620    # Get the zero-temperature (beta=1) samples and store them in the result 

621 self.result.samples = self.chain_array[ 

622 :, self.discard + self.nburn : self.iteration : self.thin, : 

623 ].reshape((-1, self.ndim)) 

624 loglikelihood = self.log_likelihood_array[ 

625 0, :, self.discard + self.nburn : self.iteration : self.thin 

626 ] # nwalkers, nsteps 

627 self.result.log_likelihood_evaluations = loglikelihood.reshape((-1)) 

628 

629 if self.store_walkers: 

630 self.result.walkers = self.sampler.chain 

631 self.result.nburn = self.nburn 

632 self.result.discard = self.discard 

633 

634 log_evidence, log_evidence_err = compute_evidence( 

635 sampler, 

636 self.log_likelihood_array, 

637 self.outdir, 

638 self.label, 

639 self.discard, 

640 self.nburn, 

641 self.thin, 

642 self.iteration, 

643 ) 

644 self.result.log_evidence = log_evidence 

645 self.result.log_evidence_err = log_evidence_err 

646 

647 self.result.sampling_time = datetime.timedelta( 

648 seconds=np.sum(self.time_per_check) 

649 ) 

650 

651 self._close_pool() 

652 

653 return self.result 

654 

655 def write_current_state(self, plot=True): 

656 check_directory_exists_and_if_not_mkdir(self.outdir) 

657 checkpoint( 

658 self.iteration, 

659 self.outdir, 

660 self.label, 

661 self.nsamples_effective, 

662 self.sampler, 

663 self.discard, 

664 self.nburn, 

665 self.thin, 

666 self.search_parameter_keys, 

667 self.resume_file, 

668 self.log_likelihood_array, 

669 self.log_posterior_array, 

670 self.chain_array, 

671 self.pos0, 

672 self.beta_list, 

673 self.tau_list, 

674 self.tau_list_n, 

675 self.Q_list, 

676 self.time_per_check, 

677 ) 

678 

679 if plot: 

680 try: 

681 # Generate the walkers plot diagnostic 

682 plot_walkers( 

683 self.chain_array[:, : self.iteration, :], 

684 self.nburn, 

685 self.thin, 

686 self.search_parameter_keys, 

687 self.outdir, 

688 self.label, 

689 self.discard, 

690 ) 

691 except Exception as e: 

692 logger.info(f"Walkers plot failed with exception {e}") 

693 

694 try: 

695 # Generate the tau plot diagnostic if DEBUG 

696 if logger.level < logging.INFO: 

697 plot_tau( 

698 self.tau_list_n, 

699 self.tau_list, 

700 self.search_parameter_keys, 

701 self.outdir, 

702 self.label, 

703 self.tau_int, 

704 self.convergence_inputs.autocorr_tau, 

705 ) 

706 except Exception as e: 

707 logger.info(f"tau plot failed with exception {e}") 

708 

709 try: 

710 plot_mean_log_posterior( 

711 self.mean_log_posterior, 

712 self.outdir, 

713 self.label, 

714 ) 

715 except Exception as e: 

716 logger.info(f"mean_logl plot failed with exception {e}") 

717 

718 @classmethod 

719 def get_expected_outputs(cls, outdir=None, label=None): 

720 """Get lists of the expected outputs directories and files. 

721 

722 These are used by :code:`bilby_pipe` when transferring files via HTCondor. 

723 

724 Parameters 

725 ---------- 

726 outdir : str 

727 The output directory. 

728 label : str 

729 The label for the run. 

730 

731 Returns 

732 ------- 

733 list 

734 List of file names. 

735 list 

736 List of directory names. Will always be empty for ptemcee. 

737 """ 

738 filenames = [f"{outdir}/{label}_checkpoint_resume.pickle"] 

739 return filenames, [] 

740 

741 

742def get_minimum_stable_itertion(mean_array, frac, nsteps_min=10): 
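    """Return the number of leading iterations to discard as burn-in.

    For each row of ``mean_array`` (the mean log-posterior of one temperature
    as a function of iteration), find the first iteration at which the value
    comes within a fraction ``frac`` of that row's maximum; the largest such
    index over all rows is returned. If fewer than ``nsteps_min`` iterations
    are available, 0 is returned.
    """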

743 nsteps = mean_array.shape[1] 

744 if nsteps < nsteps_min: 

745 return 0 

746 

747 min_it = 0 

748 for x in mean_array: 

749 maxl = np.max(x) 

750 fracdiff = (maxl - x) / np.abs(maxl) 

751 idxs = fracdiff < frac 

752 if np.sum(idxs) > 0: 

753 min_it = np.max([min_it, np.min(np.arange(len(idxs))[idxs])]) 

754 return min_it 

755 

756 

757def check_iteration( 

758 iteration, 

759 samples, 

760 sampler, 

761 convergence_inputs, 

762 search_parameter_keys, 

763 time_per_check, 

764 beta_list, 

765 tau_list, 

766 tau_list_n, 

767 gelman_rubin_list, 

768 mean_log_posterior, 

769 verbose=True, 

770): 

771 """Per-iteration logic to calculate the convergence check. 

772 

773 To check convergence, this function does the following: 

774 1. Calculate the autocorrelation time (tau) for each dimension for each walker, 

775 corresponding to those dimensions in search_parameter_keys that aren't 

776 specifically excluded in ci.ignore_keys_for_tau. 

777 a. Store the average tau for each dimension, averaged over each walker. 

778 2. Calculate the Gelman-Rubin statistic (see `get_Q_convergence`), measuring 

779 the convergence of the ensemble of walkers. 

780 3. Calculate the number of effective samples; we aggregate the total number 

781 of burned-in samples (amongst all walkers), divided by a multiple of the 

782 current maximum average autocorrelation time. Tuned by `ci.burn_in_nact` 

783 and `ci.thin_by_nact`. 

784 4. If the Gelman-Rubin statistic < `ci.Q_tol` and `ci.nsamples` < the 

785 number of effective samples, we say that our ensemble is converged, 

786 setting `converged = True`. 

787 5. For some number of the latest steps (set by the autocorrelation time 

788    and the GRAD_WINDOW_LENGTH parameter), we find the maximum gradient 

789 of the autocorrelation time over all of our dimensions, over all walkers 

790 (autocorrelation time is already averaged over walkers) and the maximum 

791 value of the gradient of the mean log posterior over iterations, over 

792 all walkers. 

793 6. If the maximum gradient in tau is less than `ci.gradient_tau` and the 

794 maximum gradient in the mean log posterior is less than 

795 `ci.gradient_mean_log_posterior`, we set `tau_usable = True`. 

796 7. If both `converged` and `tau_usable` are true, we return `stop = True`, 

797 indicating that our ensemble is converged + burnt in on this 

798 iteration. 

799 8. Also prints progress! (see `print_progress`) 

800 

801 Notes 

802 ----- 

803 The gradient of tau is computed with a Savgol-Filter, over windows in 

804 sample number of length `GRAD_WINDOW_LENGTH`. This value must be an odd integer. 

805 For `ndim > 3`, we calculate this as the nearest odd integer to ndim. 

806 For `ndim <= 3`, we calculate this as the nearest odd integer to nwalkers, as 

807 typically a much larger window length than polynomial order (default 2) leads 

808 to more stable smoothing. 

809 

810 Parameters 

811 ---------- 

812 iteration: int 

813 Number indexing the current iteration, at which we are checking 

814 convergence. 

815 samples: np.ndarray 

816 Array of ensemble MCMC samples, shaped like (number of walkers, number 

817 of MCMC steps, number of dimensions). 

818 sampler: bilby.core.sampler.Ptemcee 

819 Bilby Ptemcee sampler object; in particular, this function uses the list 

820 of walker temperatures stored in `sampler.betas`. 

821 convergence_inputs: bilby.core.sampler.ptemcee.ConvergenceInputs 

822 A named tuple of the convergence checking inputs 

823 search_parameter_keys: list 

824 A list of the search parameter keys 

825 time_per_check, tau_list, tau_list_n: list 

826 Lists used for tracking the run 

827 beta_list: list 

828 List of floats storing the walker inverse temperatures. 

829 tau_list: list 

830 List of average autocorrelation times for each dimension, averaged 

831 over walkers, at each checked iteration. So, an effective shape 

832 of (number of iterations so far, number of dimensions). 

833 tau_list_n: list 

834 List of iteration numbers, enumerating the first "axis" of tau_list. 

835 E.g. if tau_list_n[1] = 5, this means that the list found at 

836 tau_list[1] was calculated on iteration number 5. 

837 gelman_rubin_list: list (floats) 

838 list of values of the Gelman-Rubin statistic; the value calculated 

839 in this call of check_iteration is appended to the gelman_rubin_list. 

840 mean_log_posterior: np.ndarray 

841    Float array shaped like (number of temperatures, number of MCMC steps), 

842    with the log-posterior at each step, averaged over walkers. 

843 verbose: bool 

844 Whether to print the output 

845 

846 Returns 

847 ------- 

848 stop: bool 

849 A boolean flag, True if the stopping criteria has been met 

850 burn: int 

851 The number of burn-in steps to discard 

852 thin: int 

853 The thin-by factor to apply 

854 tau_int: int 

855 The integer estimated ACT 

856 nsamples_effective: int 

857 The effective number of samples after burning and thinning 

858 """ 

859 

860 ci = convergence_inputs 

861 

862 nwalkers, nsteps, ndim = samples.shape 

863 tau_array = calculate_tau_array(samples, search_parameter_keys, ci) 

864 tau = np.max(np.mean(tau_array, axis=0)) 

865 

866    # Apply multiplicative safety factor 

867 tau = ci.safety * tau 

868 

869 # Store for convergence checking and plotting 

870 beta_list.append(list(sampler.betas)) 

871 tau_list.append(list(np.mean(tau_array, axis=0))) 

872 tau_list_n.append(iteration) 

873 

874 gelman_rubin_statistic = get_Q_convergence(samples) 

875 gelman_rubin_list.append(gelman_rubin_statistic) 

876 

877 if np.isnan(tau) or np.isinf(tau): 

878 if verbose: 

879 print_progress( 

880 iteration, 

881 sampler, 

882 time_per_check, 

883 np.nan, 

884 np.nan, 

885 np.nan, 

886 np.nan, 

887 np.nan, 

888 False, 

889 convergence_inputs, 

890 gelman_rubin_statistic, 

891 ) 

892 return False, np.nan, np.nan, np.nan, np.nan 

893 

894 # Convert to an integer 

895 tau_int = int(np.ceil(tau)) 

896 

897 # Calculate the effective number of samples available 

898 nburn = int(ci.burn_in_nact * tau_int) 

899 thin = int(np.max([1, ci.thin_by_nact * tau_int])) 

900 samples_per_check = nwalkers / thin 

901 nsamples_effective = int(nwalkers * (nsteps - nburn) / thin) 

902 

903 # Calculate convergence boolean 

904 converged = gelman_rubin_statistic < ci.Q_tol and ci.nsamples < nsamples_effective 

905 logger.debug( 

906 "Convergence: Q<Q_tol={}, nsamples<nsamples_effective={}".format( 

907 gelman_rubin_statistic < ci.Q_tol, ci.nsamples < nsamples_effective 

908 ) 

909 ) 

910 
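    # savgol_filter requires an odd window length: use the nearest odd integer
    # to ndim, falling back to (approximately) nwalkers when ndim is small.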

911 GRAD_WINDOW_LENGTH = 2 * ((ndim + 1) // 2) + 1 

912 if GRAD_WINDOW_LENGTH <= 3: 

913 GRAD_WINDOW_LENGTH = 2 * (nwalkers // 2) + 1 

914 

915 nsteps_to_check = ci.autocorr_tau * np.max([2 * GRAD_WINDOW_LENGTH, tau_int]) 

916 lower_tau_index = np.max([0, len(tau_list) - nsteps_to_check]) 

917 check_taus = np.array(tau_list[lower_tau_index:]) 

918 if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > GRAD_WINDOW_LENGTH: 

919 gradient_tau = get_max_gradient( 

920 check_taus, axis=0, window_length=GRAD_WINDOW_LENGTH 

921 ) 

922 

923 if gradient_tau < ci.gradient_tau: 

924 logger.debug( 

925 f"tau usable as {gradient_tau} < gradient_tau={ci.gradient_tau}" 

926 ) 

927 tau_usable = True 

928 else: 

929 logger.debug( 

930 f"tau not usable as {gradient_tau} > gradient_tau={ci.gradient_tau}" 

931 ) 

932 tau_usable = False 

933 

934 check_mean_log_posterior = mean_log_posterior[:, -nsteps_to_check:] 

935 gradient_mean_log_posterior = get_max_gradient( 

936 check_mean_log_posterior, 

937 axis=1, 

938 window_length=GRAD_WINDOW_LENGTH, 

939 smooth=True, 

940 ) 

941 

942 if gradient_mean_log_posterior < ci.gradient_mean_log_posterior: 

943 logger.debug( 

944 f"tau usable as {gradient_mean_log_posterior} < " 

945 f"gradient_mean_log_posterior={ci.gradient_mean_log_posterior}" 

946 ) 

947 tau_usable *= True 

948 else: 

949 logger.debug( 

950 f"tau not usable as {gradient_mean_log_posterior} > " 

951 f"gradient_mean_log_posterior={ci.gradient_mean_log_posterior}" 

952 ) 

953 tau_usable = False 

954 

955 else: 

956 logger.debug("ACT is nan") 

957 gradient_tau = np.nan 

958 gradient_mean_log_posterior = np.nan 

959 tau_usable = False 

960 

961 if nsteps < tau_int * ci.autocorr_tol: 

962 logger.debug("ACT less than autocorr_tol") 

963 tau_usable = False 

964 elif tau_int < ci.min_tau: 

965 logger.debug("ACT less than min_tau") 

966 tau_usable = False 

967 

968 # Print an update on the progress 

969 if verbose: 

970 print_progress( 

971 iteration, 

972 sampler, 

973 time_per_check, 

974 nsamples_effective, 

975 samples_per_check, 

976 tau_int, 

977 gradient_tau, 

978 gradient_mean_log_posterior, 

979 tau_usable, 

980 convergence_inputs, 

981 gelman_rubin_statistic, 

982 ) 

983 

984 stop = converged and tau_usable 

985 return stop, nburn, thin, tau_int, nsamples_effective 

986 

987 

988def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False): 

989 """Calculate the maximum value of the gradient in the input data. 

990 

991 Applies a Savitzky-Golay filter (`scipy.signal.savgol_filter`) to the input 

992 data x, along a particular axis. This filter smooths the data and, as configured 

993 in this function, simultaneously calculates the derivative of the smoothed data. 

994 If smooth=True is provided, it will apply a Savitzky-Golay filter with a 

995 polynomial order of 3 to the input data before applying this filter a second 

996 time and calculating the derivative. This function will return the maximum value 

997 of the derivative returned by the filter. 

998 

999 See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.savgol_filter.html 

1000 for more information on the Savitzky-Golay filter that we use. Some parameter 

1001 documentation has been borrowed from this source. 

1002 

1003 Parameters 

1004 ---------- 

1005 x : np.ndarray 

1006 Array of input data (can be int or float, as `savgol_filter` casts to float). 

1007 axis : int, default = 0 

1008 The axis of the input array x over which to calculate the gradient. 

1009 window_length : int, default = 11 

1010 The length of the filter window (i.e., the number of coefficients to use 

1011 in approximating the data). 

1012 polyorder : int, default = 2 

1013 The order of the polynomial used to fit the samples. polyorder must be less 

1014 than window_length. 

1015 smooth : bool, default = False 

1016 If true, this will smooth the data with a Savitzky-Golay filter before 

1017    providing it to the Savitzky-Golay filter for calculating the derivative. 

1018 Probably useful if you think your input data is especially noisy. 

1019 

1020 Returns 

1021 ------- 

1022 max_gradient : float 

1023 Maximum value of the gradient. 

1024 """ 

1025 

1026 from scipy.signal import savgol_filter 

1027 

1028 if smooth: 

1029 x = savgol_filter(x, axis=axis, window_length=window_length, polyorder=3) 

1030 return np.max( 

1031 savgol_filter( 

1032 x, axis=axis, window_length=window_length, polyorder=polyorder, deriv=1 

1033 ) 

1034 ) 
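
# Illustrative call of get_max_gradient (kept as a comment so it does not run
# on import; the data are made up): estimate the largest local slope of a
# noisy multi-dimensional trace after smoothing.
#
#     trace = np.cumsum(np.random.default_rng(1).normal(size=(200, 4)), axis=0)
#     get_max_gradient(trace, axis=0, window_length=11, smooth=True)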

1035 

1036 

1037def get_Q_convergence(samples): 

1038 """Calculate the Gelman-Rubin statistic as an estimate of convergence for 

1039 an ensemble of MCMC walkers. 

1040 

1041 Calculates the Gelman-Rubin statistic, from Gelman and Rubin (1992). 

1042 See section 2.2 of Gelman and Rubin (1992), at 

1043 https://doi.org/10.1214/ss/1177011136. 

1044 

1045 There is also a good description of this statistic in section 7.4.2 

1046 of "Advanced Statistical Computing" (Peng 2021), in-progress course notes, 

1047 currently found at 

1048 https://bookdown.org/rdpeng/advstatcomp/monitoring-convergence.html. 

1049 As of this writing, L in this resource refers to the number of sampling 

1050 steps remaining after some have been discarded to achieve burn-in, 

1051 equivalent to nsteps here. Paraphrasing, we compare the variance between 

1052 our walkers (chains) to the variance within each walker (compare 

1053 inter-walker vs. intra-walker variance). We do this because our walkers 

1054 should be indistinguishable from one another when they reach a steady-state, 

1055 i.e. convergence. Looking at V-hat in the definition of this function, we 

1056 can see that as nsteps -> infinity, B (inter-chain variance) -> 0, 

1057 R -> 1; so, R >~ 1 is a good condition for the convergence of our ensemble. 

1058 

1059 In practice, this function calculates the Gelman-Rubin statistic for 

1060 each dimension, and then returns the largest of these values. This 

1061 means that we can be sure that, once the walker with the largest such value 

1062 achieves a Gelman-Rubin statistic of >~ 1, the others have as well. 

1063 

1064 Parameters 

1065 ---------- 

1066 samples: np.ndarray 

1067 Array of ensemble MCMC samples, shaped like (number of walkers, number 

1068 of MCMC steps, number of dimensions). 

1069 

1070 Returns 

1071 ------- 

1072 Q: float 

1073 The largest value of the Gelman-Rubin statistic, from those values 

1074 calculated for each dimension. If only one step is represented in 

1075 samples, this returns np.inf. 

1076 """ 

1077 

1078 nwalkers, nsteps, ndim = samples.shape 

1079 if nsteps > 1: 

1080 W = np.mean(np.var(samples, axis=1), axis=0) 

1081 

1082 per_walker_mean = np.mean(samples, axis=1) 

1083 mean = np.mean(per_walker_mean, axis=0) 

1084 B = nsteps / (nwalkers - 1.0) * np.sum((per_walker_mean - mean) ** 2, axis=0) 

1085 

1086 Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B 

1087 Q_per_dim = np.sqrt(Vhat / W) 

1088 return np.max(Q_per_dim) 

1089 else: 

1090 return np.inf 
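
# Illustrative sanity check for get_Q_convergence (kept as a comment so it
# does not run on import): independent, well-mixed chains drawn from the same
# distribution should give a value close to 1, e.g.
#
#     rng = np.random.default_rng(0)
#     chains = rng.normal(size=(10, 500, 3))  # (nwalkers, nsteps, ndim)
#     get_Q_convergence(chains)               # expected to be close to 1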

1091 

1092 

1093def print_progress( 

1094 iteration, 

1095 sampler, 

1096 time_per_check, 

1097 nsamples_effective, 

1098 samples_per_check, 

1099 tau_int, 

1100 gradient_tau, 

1101 gradient_mean_log_posterior, 

1102 tau_usable, 

1103 convergence_inputs, 

1104 Q, 

1105): 

1106 # Setup acceptance string 

1107 acceptance = sampler.acceptance_fraction[0, :] 

1108 acceptance_str = f"{np.min(acceptance):1.2f}-{np.max(acceptance):1.2f}" 

1109 

1110 # Setup tswap acceptance string 

1111 tswap_acceptance_fraction = sampler.tswap_acceptance_fraction 

1112 tswap_acceptance_str = f"{np.min(tswap_acceptance_fraction):1.2f}-{np.max(tswap_acceptance_fraction):1.2f}" 

1113 

1114 ave_time_per_check = np.mean(time_per_check[-3:]) 

1115 time_left = ( 

1116 (convergence_inputs.nsamples - nsamples_effective) 

1117 * ave_time_per_check 

1118 / samples_per_check 

1119 ) 

1120 if time_left > 0: 

1121 time_left = str(datetime.timedelta(seconds=int(time_left))) 

1122 else: 

1123 time_left = "waiting on convergence" 

1124 

1125 sampling_time = datetime.timedelta(seconds=np.sum(time_per_check)) 

1126 

1127 tau_str = f"{tau_int}(+{gradient_tau:0.2f},+{gradient_mean_log_posterior:0.2f})" 

1128 

1129 if tau_usable: 

1130 tau_str = f"={tau_str}" 

1131 else: 

1132 tau_str = f"!{tau_str}" 

1133 

1134 Q_str = f"{Q:0.2f}" 

1135 

1136 evals_per_check = ( 

1137 sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check 

1138 ) 

1139 

1140 approximate_ncalls = ( 

1141 convergence_inputs.niterations_per_check 

1142 * iteration 

1143 * sampler.nwalkers 

1144 * sampler.ntemps 

1145 ) 

1146 ncalls = f"{approximate_ncalls:1.1e}" 

1147 eval_timing = f"{1000.0 * ave_time_per_check / evals_per_check:1.2f}ms/ev" 

1148 

1149 try: 

1150 print( 

1151 f"{iteration}|{str(sampling_time).split('.')[0]}|nc:{ncalls}|" 

1152 f"a0:{acceptance_str}|swp:{tswap_acceptance_str}|" 

1153 f"n:{nsamples_effective}<{convergence_inputs.nsamples}|t{tau_str}|" 

1154 f"q:{Q_str}|{eval_timing}", 

1155 flush=True, 

1156 ) 

1157 except OSError as e: 

1158 logger.debug(f"Failed to print iteration due to :{e}") 

1159 

1160 

1161def calculate_tau_array(samples, search_parameter_keys, ci): 

1162 """Calculate the autocorrelation time for zero-temperature chains. 

1163 

1164 Calculates the autocorrelation time for each chain, for those parameters/ 

1165 dimensions that are not explicitly excluded in ci.ignore_keys_for_tau. 

1166 

1167 Parameters 

1168 ---------- 

1169 samples: np.ndarray 

1170 Array of ensemble MCMC samples, shaped like (number of walkers, number 

1171 of MCMC steps, number of dimensions). 

1172 search_parameter_keys: list 

1173 A list of the search parameter keys 

1174 ci : collections.namedtuple 

1175 Collection of settings for convergence tests, including autocorrelation 

1176 calculation. If a value in search_parameter_keys is included in 

1177 ci.ignore_keys_for_tau, this function will not calculate an 

1178 autocorrelation time for any walker along that particular dimension. 

1179 

1180 Returns 

1181 ------- 

1182 tau_array: np.ndarray 

1183 Float array shaped like (nwalkers, ndim) (with all np.inf for any 

1184 dimension that is excluded by ci.ignore_keys_for_tau). 

1185 """ 

1186 

1187 import emcee 

1188 

1189 nwalkers, nsteps, ndim = samples.shape 

1190 tau_array = np.zeros((nwalkers, ndim)) + np.inf 

1191 if nsteps > 1: 

1192 for ii in range(nwalkers): 

1193 for jj, key in enumerate(search_parameter_keys): 

1194 if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key: 

1195 continue 

1196 try: 

1197 tau_array[ii, jj] = emcee.autocorr.integrated_time( 

1198 samples[ii, :, jj], c=ci.autocorr_c, tol=0 

1199 )[0] 

1200 except emcee.autocorr.AutocorrError: 

1201 tau_array[ii, jj] = np.inf 

1202 return tau_array 

1203 

1204 

1205def checkpoint( 

1206 iteration, 

1207 outdir, 

1208 label, 

1209 nsamples_effective, 

1210 sampler, 

1211 discard, 

1212 nburn, 

1213 thin, 

1214 search_parameter_keys, 

1215 resume_file, 

1216 log_likelihood_array, 

1217 log_posterior_array, 

1218 chain_array, 

1219 pos0, 

1220 beta_list, 

1221 tau_list, 

1222 tau_list_n, 

1223 Q_list, 

1224 time_per_check, 

1225): 

1226 logger.info("Writing checkpoint and diagnostics") 

1227 ndim = sampler.dim 

1228 

1229 # Store the samples if possible 

1230 if nsamples_effective > 0: 

1231 filename = f"{outdir}/{label}_samples.txt" 

1232 samples = np.array(chain_array)[ 

1233 :, discard + nburn : iteration : thin, : 

1234 ].reshape((-1, ndim)) 

1235 df = pd.DataFrame(samples, columns=search_parameter_keys) 

1236 df.to_csv(filename, index=False, header=True, sep=" ") 

1237 

1238 # Pickle the resume artefacts 

1239 pool = sampler.pool 

1240 sampler.pool = None 

1241 sampler_copy = copy.deepcopy(sampler) 

1242 sampler.pool = pool 

1243 

1244 data = dict( 

1245 iteration=iteration, 

1246 sampler=sampler_copy, 

1247 beta_list=beta_list, 

1248 tau_list=tau_list, 

1249 tau_list_n=tau_list_n, 

1250 Q_list=Q_list, 

1251 time_per_check=time_per_check, 

1252 log_likelihood_array=log_likelihood_array, 

1253 log_posterior_array=log_posterior_array, 

1254 chain_array=chain_array, 

1255 pos0=pos0, 

1256 ) 

1257 

1258 safe_file_dump(data, resume_file, "dill") 

1259 del data, sampler_copy 

1260 logger.info("Finished writing checkpoint") 

1261 

1262 

1263def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, discard=0): 

1264 """Method to plot the trace of the walkers in an ensemble MCMC plot""" 

1265 import matplotlib.pyplot as plt 

1266 

1267 nwalkers, nsteps, ndim = walkers.shape 

1268 if np.isnan(nburn): 

1269 nburn = nsteps 

1270 if np.isnan(thin): 

1271 thin = 1 

1272 idxs = np.arange(nsteps) 

1273 fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim)) 

1274 scatter_kwargs = dict( 

1275 lw=0, 

1276 marker="o", 

1277 markersize=1, 

1278 alpha=0.1, 

1279 ) 

1280 

1281 # Plot the fixed burn-in 

1282 if discard > 0: 

1283 for i, (ax, axh) in enumerate(axes): 

1284 ax.plot( 

1285 idxs[:discard], 

1286 walkers[:, :discard, i].T, 

1287 color="gray", 

1288 **scatter_kwargs, 

1289 ) 

1290 

1291 # Plot the burn-in 

1292 for i, (ax, axh) in enumerate(axes): 

1293 ax.plot( 

1294 idxs[discard : discard + nburn + 1], 

1295 walkers[:, discard : discard + nburn + 1, i].T, 

1296 color="C1", 

1297 **scatter_kwargs, 

1298 ) 

1299 

1300 # Plot the thinned posterior samples 

1301 for i, (ax, axh) in enumerate(axes): 

1302 ax.plot( 

1303 idxs[discard + nburn :: thin], 

1304 walkers[:, discard + nburn :: thin, i].T, 

1305 color="C0", 

1306 **scatter_kwargs, 

1307 ) 

1308 axh.hist( 

1309 walkers[:, discard + nburn :: thin, i].reshape((-1)), bins=50, alpha=0.8 

1310 ) 

1311 

1312 for i, (ax, axh) in enumerate(axes): 

1313 axh.set_xlabel(parameter_labels[i]) 

1314 ax.set_ylabel(parameter_labels[i]) 

1315 

1316 fig.tight_layout() 

1317 filename = f"{outdir}/{label}_checkpoint_trace.png" 

1318 fig.savefig(filename) 

1319 plt.close(fig) 

1320 

1321 

1322def plot_tau( 

1323 tau_list_n, 

1324 tau_list, 

1325 search_parameter_keys, 

1326 outdir, 

1327 label, 

1328 tau, 

1329 autocorr_tau, 

1330): 

1331 import matplotlib.pyplot as plt 

1332 

1333 fig, ax = plt.subplots() 

1334 for i, key in enumerate(search_parameter_keys): 

1335 ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key) 

1336 ax.set_xlabel("Iteration") 

1337 ax.set_ylabel(r"$\langle \tau \rangle$") 

1338 ax.legend() 

1339 fig.tight_layout() 

1340 fig.savefig(f"{outdir}/{label}_checkpoint_tau.png") 

1341 plt.close(fig) 

1342 

1343 

1344def plot_mean_log_posterior(mean_log_posterior, outdir, label): 

1345 import matplotlib.pyplot as plt 

1346 

1347 mean_log_posterior[mean_log_posterior < -1e100] = np.nan 

1348 

1349 ntemps, nsteps = mean_log_posterior.shape 

1350 ymax = np.nanmax(mean_log_posterior) 

1351 ymin = np.nanmin(mean_log_posterior[:, -100:]) 

1352 ymax += 0.1 * (ymax - ymin) 

1353 ymin -= 0.1 * (ymax - ymin) 

1354 

1355 fig, ax = plt.subplots() 

1356 idxs = np.arange(nsteps) 

1357 ax.plot(idxs, mean_log_posterior.T) 

1358 ax.set( 

1359 xlabel="Iteration", 

1360 ylabel=r"$\langle\mathrm{log-posterior}\rangle$", 

1361 ylim=(ymin, ymax), 

1362 ) 

1363 fig.tight_layout() 

1364 fig.savefig(f"{outdir}/{label}_checkpoint_meanlogposterior.png") 

1365 plt.close(fig) 

1366 

1367 

1368def compute_evidence( 

1369 sampler, 

1370 log_likelihood_array, 

1371 outdir, 

1372 label, 

1373 discard, 

1374 nburn, 

1375 thin, 

1376 iteration, 

1377 make_plots=True, 

1378): 

1379 """Computes the evidence using thermodynamic integration""" 

1380 import matplotlib.pyplot as plt 

1381 

1382 betas = sampler.betas 

1383 # We compute the evidence without the burnin samples, but we do not thin 

1384 lnlike = log_likelihood_array[:, :, discard + nburn : iteration] 

1385 mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1) 

1386 

1387 mean_lnlikes = mean_lnlikes[::-1] 

1388 betas = betas[::-1] 

1389 

1390 if any(np.isinf(mean_lnlikes)): 

1391 logger.warning( 

1392 "mean_lnlikes contains inf: recalculating without" 

1393 f" the {len(betas[np.isinf(mean_lnlikes)])} infs" 

1394 ) 

1395 idxs = np.isinf(mean_lnlikes) 

1396 mean_lnlikes = mean_lnlikes[~idxs] 

1397 betas = betas[~idxs] 

1398 

1399 lnZ = np.trapz(mean_lnlikes, betas) 

1400 z1 = np.trapz(mean_lnlikes, betas) 

1401 z2 = np.trapz(mean_lnlikes[::-1][::2][::-1], betas[::-1][::2][::-1]) 

1402 lnZerr = np.abs(z1 - z2) 

1403 

1404 if make_plots: 

1405 fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(6, 8)) 

1406 ax1.semilogx(betas, mean_lnlikes, "-o") 

1407 ax1.set_xlabel(r"$\beta$") 

1408 ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$") 

1409 min_betas = [] 

1410 evidence = [] 

1411 for i in range(int(len(betas) / 2.0)): 

1412 min_betas.append(betas[i]) 

1413 evidence.append(np.trapz(mean_lnlikes[i:], betas[i:])) 

1414 

1415 ax2.semilogx(min_betas, evidence, "-o") 

1416 ax2.set_ylabel( 

1417 r"$\int_{\beta_{min}}^{\beta=1}" 

1418 + r"\langle \log(\mathcal{L})\rangle d\beta$", 

1419 size=16, 

1420 ) 

1421 ax2.set_xlabel(r"$\beta_{min}$") 

1422 plt.tight_layout() 

1423 fig.savefig(f"{outdir}/{label}_beta_lnl.png") 

1424 plt.close(fig) 

1425 

1426 return lnZ, lnZerr 

1427 

1428 

1429def do_nothing_function(): 

1430 """This is a do-nothing function, we overwrite the likelihood and prior elsewhere""" 

1431 pass 

1432 

1433 

1434class LikePriorEvaluator(object): 

1435 """ 

1436 This class is copied and modified from ptemcee.LikePriorEvaluator, see 

1437 https://github.com/willvousden/ptemcee for the original version 

1438 

1439 We overwrite the logl and logp methods in order to improve the performance 

1440 when using a MultiPool object: essentially reducing the amount of data 

1441 transfer overhead. 

1442 

1443 """ 

1444 

1445 def __init__(self): 

1446 self.periodic_set = False 

1447 

1448 def _setup_periodic(self): 

1449 priors = _sampling_convenience_dump.priors 

1450 search_parameter_keys = _sampling_convenience_dump.search_parameter_keys 

1451 self._periodic = [ 

1452 priors[key].boundary == "periodic" for key in search_parameter_keys 

1453 ] 

1454 priors.sample() 

1455 self._minima = np.array([priors[key].minimum for key in search_parameter_keys]) 

1456 self._range = ( 

1457 np.array([priors[key].maximum for key in search_parameter_keys]) 

1458 - self._minima 

1459 ) 

1460 self.periodic_set = True 

1461 

1462 def _wrap_periodic(self, array): 

1463 if not self.periodic_set: 

1464 self._setup_periodic() 

1465 array[self._periodic] = ( 

1466 np.mod( 

1467 array[self._periodic] - self._minima[self._periodic], 

1468 self._range[self._periodic], 

1469 ) 

1470 + self._minima[self._periodic] 

1471 ) 

1472 return array 

1473 

1474 def logl(self, v_array): 

1475 priors = _sampling_convenience_dump.priors 

1476 likelihood = _sampling_convenience_dump.likelihood 

1477 search_parameter_keys = _sampling_convenience_dump.search_parameter_keys 

1478 parameters = {key: v for key, v in zip(search_parameter_keys, v_array)} 

1479 if priors.evaluate_constraints(parameters) > 0: 

1480 likelihood.parameters.update(parameters) 

1481 if _sampling_convenience_dump.use_ratio: 

1482 return likelihood.log_likelihood() - likelihood.noise_log_likelihood() 

1483 else: 

1484 return likelihood.log_likelihood() 

1485 else: 

1486 return np.nan_to_num(-np.inf) 

1487 

1488 def logp(self, v_array): 

1489 priors = _sampling_convenience_dump.priors 

1490 search_parameter_keys = _sampling_convenience_dump.search_parameter_keys 

1491 params = {key: t for key, t in zip(search_parameter_keys, v_array)} 

1492 return priors.ln_prob(params) 

1493 

1494 def call_emcee(self, theta): 

1495 ll, lp = self.__call__(theta) 

1496 return ll + lp, [ll, lp] 

1497 

1498 def __call__(self, x): 

1499 lp = self.logp(x) 

1500 if np.isnan(lp): 

1501 raise ValueError("Prior function returned NaN.") 

1502 

1503 if lp == float("-inf"): 

1504 # Can't return -inf, since this messes with beta=0 behaviour. 

1505 ll = 0 

1506 else: 

1507 ll = self.logl(x) 

1508 if np.isnan(ll).any(): 

1509 raise ValueError("Log likelihood function returned NaN.") 

1510 

1511 return ll, lp