Coverage for bilby/bilby_mcmc/sampler.py: 57%

752 statements  

coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

import datetime
import os
import time
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.optimize import differential_evolution

from ..core.result import rejection_sample
from ..core.sampler.base_sampler import (
    MCMCSampler,
    ResumeError,
    SamplerError,
    _sampling_convenience_dump,
    signal_wrapper,
)
from ..core.utils import (
    check_directory_exists_and_if_not_mkdir,
    logger,
    random,
    safe_file_dump,
)
from . import proposals
from .chain import Chain, Sample
from .utils import LOGLKEY, LOGPKEY, ConvergenceInputs, ParallelTemperingInputs


class Bilby_MCMC(MCMCSampler):
    """The built-in Bilby MCMC sampler

    Parameters
    ----------
    likelihood: likelihood.Likelihood
        An object with a log_l method
    priors: bilby.core.prior.PriorDict, dict
        Priors to be used in the search.
        This has attributes for each parameter to be sampled.
    outdir: str, optional
        Name of the output directory
    label: str, optional
        Naming scheme of the output files
    use_ratio: bool, optional
        Switch to set whether or not you want to use the log-likelihood ratio
        or just the log-likelihood
    skip_import_verification: bool
        Skips the check if the sampler is installed if true. This is
        only advisable for testing environments
    check_point_plot: bool
        If true, create plots at the check point
    check_point_delta_t: float
        The time in seconds after which to checkpoint (defaults to 30 minutes)
    diagnostic: bool
        If true, create deep-diagnostic plots used for checking convergence
        problems.
    resume: bool
        If true, resume from any existing check point files
    exit_code: int
        The code on which to raise if exiting
    nsamples: int (1000)
        The number of samples to draw
    nensemble: int (1)
        The number of ensemble-chains to run (with periodic communication)
    pt_ensemble: bool (False)
        If true, run a parallel-tempered set of chains for each
        ensemble-chain (in which case the total number of chains is
        nensemble * ntemps). Else, only the zero-ensemble chain is run with a
        parallel-tempering (in which case the total number of chains is
        nensemble + ntemps - 1).
    ntemps: int (1)
        The number of parallel-tempered chains to run
    Tmax: float (None)
        If given, the maximum temperature used to set the initial
        temperature ladder
    Tmax_from_SNR: float (20)
        (Alternative to Tmax): The SNR to estimate an appropriate Tmax from.
    initial_betas: list (None)
        (Alternative to Tmax and Tmax_from_SNR): If given, an initial choice of
        the inverse temperature ladder.
    pt_rejection_sample: bool (False)
        If true, use rejection sampling to draw samples from the pt-chains.
    adapt, adapt_t0, adapt_nu: bool, float, float (True, 100, 10)
        Whether to use adaptation and the adaptation parameters.
        See arXiv:1501.05823 for a description of adapt_t0 and adapt_nu.
    burn_in_nact, thin_by_nact, fixed_discard: float, float, float (10, 1, 0)
        The number of auto-correlation times to discard for burn-in and to
        thin by. The fixed_discard is the number of steps discarded before
        automatic autocorrelation time analysis begins.
    autocorr_c: float (5)
        The step-size for the window search. See emcee.autocorr.integrated_time
        for additional details.
    L1steps: int
        The number of internal steps to take. Improves the scaling performance
        of multiprocessing. Note, all ACTs are calculated based on the saved
        steps. So, the total ACT (or number of steps) is L1steps * tau
        (or L1steps * position).
    L2steps: int
        The number of steps to take before swapping between parallel-tempered
        and ensemble chains.
    npool: int
        The number of multiprocessing cores to use. For efficiency, this
        should divide evenly into the total number of chains.
    printdt: float
        Print an update on the progress every printdt s. Note, each print
        requires an evaluation of the ACT so short print times are unwise.
    min_tau: int (1)
        The minimum allowed ACT. Can be used to force a larger ACT.
    proposal_cycle: str, bilby.core.sampler.bilby_mcmc.proposals.ProposalCycle
        Either a string pointing to one of the built-in proposal cycles or
        a proposal cycle.
    stop_after_convergence: bool (False)
        If running with parallel-tempered chains, stop updating the chains
        once they have converged. After this time, random samples will be
        drawn at swap time.
    fixed_tau: int
        A fixed value for the ACT: used for testing purposes.
    tau_window: int, None
        Using tau', a previous estimate of tau, calculate the new tau using
        the last tau_window * tau' steps. If None, the entire chain is used.
    evidence_method: str, [stepping_stone, thermodynamic]
        The evidence calculation method to use. Defaults to stepping_stone, but
        the results of all available methods are stored in the ln_z_dict.
    initial_sample_method: str
        Method to draw the initial sample. Either "prior" (a random draw
        from the prior) or "maximize" (use an optimization approach to attempt
        to find the maximum posterior estimate).
    initial_sample_dict: dict
        A dictionary of initial sample values. If given, these overwrite the
        corresponding entries of the initial sample drawn using
        initial_sample_method.
    normalize_prior: bool
        When False, disables calculation of the constraint normalization
        factor during prior probability computation. Default value is True.
    verbose: bool
        Whether to print diagnostic output during the run.

    """

    default_kwargs = dict(
        nsamples=1000,
        nensemble=1,
        pt_ensemble=False,
        ntemps=1,
        Tmax=None,
        Tmax_from_SNR=20,
        initial_betas=None,
        adapt=True,
        adapt_t0=100,
        adapt_nu=10,
        pt_rejection_sample=False,
        burn_in_nact=10,
        thin_by_nact=1,
        fixed_discard=0,
        autocorr_c=5,
        L1steps=100,
        L2steps=3,
        printdt=60,
        check_point_delta_t=1800,
        min_tau=1,
        proposal_cycle="default",
        stop_after_convergence=False,
        fixed_tau=None,
        tau_window=None,
        evidence_method="stepping_stone",
        initial_sample_method="prior",
        initial_sample_dict=None,
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        skip_import_verification=True,
        check_point_plot=True,
        diagnostic=False,
        resume=True,
        exit_code=130,
        verbose=True,
        normalize_prior=True,
        **kwargs,
    ):
        super(Bilby_MCMC, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            skip_import_verification=skip_import_verification,
            exit_code=exit_code,
            **kwargs,
        )

        self.check_point_plot = check_point_plot
        self.diagnostic = diagnostic
        self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
        self.L1steps = self.kwargs["L1steps"]
        self.L2steps = self.kwargs["L2steps"]
        self.normalize_prior = normalize_prior
        self.pt_inputs = ParallelTemperingInputs(
            **{key: self.kwargs[key] for key in ParallelTemperingInputs._fields}
        )
        self.convergence_inputs = ConvergenceInputs(
            **{key: self.kwargs[key] for key in ConvergenceInputs._fields}
        )
        self.proposal_cycle = self.kwargs["proposal_cycle"]
        self.pt_rejection_sample = self.kwargs["pt_rejection_sample"]
        self.evidence_method = self.kwargs["evidence_method"]
        self.initial_sample_method = self.kwargs["initial_sample_method"]
        self.initial_sample_dict = self.kwargs["initial_sample_dict"]

        self.printdt = self.kwargs["printdt"]
        self.check_point_delta_t = self.kwargs["check_point_delta_t"]
        check_directory_exists_and_if_not_mkdir(self.outdir)
        self.resume = resume
        self.resume_file = "{}/{}_resume.pickle".format(self.outdir, self.label)

        self.verify_configuration()
        self.verbose = verbose

    def verify_configuration(self):
        if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1:
            logger.warning("Burn-in inefficiency fraction greater than 10%")

    def _translate_kwargs(self, kwargs):
        kwargs = super()._translate_kwargs(kwargs)
        if "printdt" not in kwargs:
            for equiv in ["print_dt", "print_update"]:
                if equiv in kwargs:
                    kwargs["printdt"] = kwargs.pop(equiv)
        if "npool" not in kwargs:
            for equiv in self.npool_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["npool"] = kwargs.pop(equiv)
        if "check_point_delta_t" not in kwargs:
            for equiv in self.check_point_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["check_point_delta_t"] = kwargs.pop(equiv)
        return kwargs

    @property
    def target_nsamples(self):
        return self.kwargs["target_nsamples"]

    @signal_wrapper
    def run_sampler(self):
        self._setup_pool()
        self.setup_chain_set()
        self.start_time = datetime.datetime.now()
        self.draw()
        self._close_pool()
        self.check_point(ignore_time=True)

        self.result = self.add_data_to_result(
            result=self.result,
            ptsampler=self.ptsampler,
            outdir=self.outdir,
            label=self.label,
            make_plots=self.check_point_plot,
        )

        return self.result

    @staticmethod
    def add_data_to_result(result, ptsampler, outdir, label, make_plots):
        result.samples = ptsampler.samples
        result.log_likelihood_evaluations = result.samples[LOGLKEY].to_numpy()
        result.log_prior_evaluations = result.samples[LOGPKEY].to_numpy()
        ptsampler.compute_evidence(
            outdir=outdir,
            label=label,
            make_plots=make_plots,
        )
        result.log_evidence = ptsampler.ln_z
        result.log_evidence_err = ptsampler.ln_z_err
        result.sampling_time = datetime.timedelta(seconds=ptsampler.sampling_time)
        result.meta_data["bilby_mcmc"] = dict(
            tau=ptsampler.tau,
            convergence_inputs=ptsampler.convergence_inputs._asdict(),
            pt_inputs=ptsampler.pt_inputs._asdict(),
            total_steps=ptsampler.position,
            nsamples=ptsampler.nsamples,
        )
        if ptsampler.pool is not None:
            npool = ptsampler.pool._processes
        else:
            npool = 1
        result.meta_data["run_statistics"] = dict(
            nlikelihood=ptsampler.position * ptsampler.L1steps * ptsampler._nsamplers,
            neffsamples=ptsampler.nsamples * ptsampler.convergence_inputs.thin_by_nact,
            sampling_time_s=result.sampling_time.seconds,
            ncores=npool,
        )

        return result

    def setup_chain_set(self):
        if self.read_current_state() and self.resume is True:
            self.ptsampler.pool = self.pool
        else:
            self.init_ptsampler()

    def init_ptsampler(self):
        logger.info(f"Initializing BilbyPTMCMCSampler with:\n{self.get_setup_string()}")
        self.ptsampler = BilbyPTMCMCSampler(
            convergence_inputs=self.convergence_inputs,
            pt_inputs=self.pt_inputs,
            proposal_cycle=self.proposal_cycle,
            pt_rejection_sample=self.pt_rejection_sample,
            pool=self.pool,
            use_ratio=self.use_ratio,
            evidence_method=self.evidence_method,
            initial_sample_method=self.initial_sample_method,
            initial_sample_dict=self.initial_sample_dict,
            normalize_prior=self.normalize_prior,
        )

    def get_setup_string(self):
        string = (
            f" Convergence settings: {self.convergence_inputs}\n"
            f" Parallel-tempering settings: {self.pt_inputs}\n"
            f" proposal_cycle: {self.proposal_cycle}\n"
            f" pt_rejection_sample: {self.pt_rejection_sample}"
        )
        return string

    def draw(self):
        self._steps_since_last_print = 0
        self._time_since_last_print = 0
        logger.info(f"Drawing {self.target_nsamples} samples")
        logger.info(f"Checkpoint every check_point_delta_t={self.check_point_delta_t}s")
        logger.info(f"Print update every printdt={self.printdt}s")

        while True:
            t0 = datetime.datetime.now()
            self.ptsampler.step_all_chains()
            dt = (datetime.datetime.now() - t0).total_seconds()
            self.ptsampler.sampling_time += dt
            self._time_since_last_print += dt
            self._steps_since_last_print += self.ptsampler.L1steps

            if self._time_since_last_print > self.printdt:
                tp0 = datetime.datetime.now()
                self.print_progress()
                tp = datetime.datetime.now()
                ppt_frac = (tp - tp0).total_seconds() / self._time_since_last_print
                if ppt_frac > 0.01:
                    logger.warning(
                        f"Non-negligible print progress time (ppt_frac={ppt_frac:0.2f})"
                    )
                self._steps_since_last_print = 0
                self._time_since_last_print = 0

            self.check_point()

            if self.ptsampler.nsamples_last >= self.target_nsamples:
                # Perform a second check without cached values
                if self.ptsampler.nsamples_nocache >= self.target_nsamples:
                    logger.info("Reached convergence: exiting sampling")
                    break

    def check_point(self, ignore_time=False):
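        # A checkpoint is written only when both the elapsed run time (tS)
        # and the age of the newest resume file (tR) exceed
        # check_point_delta_t, unless ignore_time forces one immediately.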

        tS = (datetime.datetime.now() - self.start_time).total_seconds()
        if os.path.isfile(self.resume_file):
            tR = time.time() - os.path.getmtime(self.resume_file)
        else:
            tR = np.inf

        if ignore_time or np.min([tS, tR]) > self.check_point_delta_t:
            logger.info("Checkpoint start")
            self.write_current_state()
            self.print_long_progress()
            logger.info("Checkpoint finished")

    def _remove_checkpoint(self):
        """Remove checkpointed state"""
        if os.path.isfile(self.resume_file):
            os.remove(self.resume_file)

    def read_current_state(self):
        """Read the existing resume file

        Returns
        -------
        success: boolean
            If true, resume file was successfully loaded, otherwise false

        """
        if os.path.isfile(self.resume_file) is False or not os.path.getsize(
            self.resume_file
        ):
            return False
        import dill

        with open(self.resume_file, "rb") as file:
            ptsampler = dill.load(file)
        if not isinstance(ptsampler, BilbyPTMCMCSampler):
            logger.debug("Malformed resume file, ignoring")
            return False
        self.ptsampler = ptsampler
        if self.ptsampler.pt_inputs != self.pt_inputs:
            msg = (
                f"pt_inputs has changed: {self.ptsampler.pt_inputs} "
                f"-> {self.pt_inputs}"
            )
            raise ResumeError(msg)
        self.ptsampler.set_convergence_inputs(self.convergence_inputs)
        self.ptsampler.pt_rejection_sample = self.pt_rejection_sample

        logger.info(
            f"Loaded resume file {self.resume_file} "
            f"with {self.ptsampler.position} steps "
            f"setup:\n{self.get_setup_string()}"
        )
        return True

    def write_current_state(self):
        import dill

        if not hasattr(self, "ptsampler"):
            logger.debug("Attempted checkpoint before initialization")
            return
        logger.debug("Check point")
        check_directory_exists_and_if_not_mkdir(self.outdir)

        _pool = self.ptsampler.pool
        self.ptsampler.pool = None
        if dill.pickles(self.ptsampler):
            safe_file_dump(self.ptsampler, self.resume_file, dill)
            logger.info("Written checkpoint file {}".format(self.resume_file))
        else:
            logger.warning(
                "Cannot write pickle resume file! Job may not resume if interrupted."
            )
            # Touch the file to postpone the next checkpoint attempt
            Path(self.resume_file).touch(exist_ok=True)
        self.ptsampler.pool = _pool

    def print_long_progress(self):
        self.print_per_proposal()
        self.print_tau_dict()
        if self.ptsampler.ntemps > 1:
            self.print_pt_acceptance()
        if self.ptsampler.nensemble > 1:
            self.print_ensemble_acceptance()
        if self.check_point_plot:
            self.plot_progress(
                self.ptsampler, self.label, self.outdir, self.priors, self.diagnostic
            )
            self.ptsampler.compute_evidence(
                outdir=self.outdir, label=self.label, make_plots=True
            )

    def print_ensemble_acceptance(self):
        logger.info(f"Ensemble swaps = {self.ptsampler.swap_counter['ensemble']}")
        logger.info(self.ptsampler.ensemble_proposal_cycle)

    def print_progress(self):
        position = self.ptsampler.position

        # Total sampling time
        sampling_time = datetime.timedelta(seconds=self.ptsampler.sampling_time)
        time = str(sampling_time).split(".")[0]

        # Time for last evaluation set
        time_per_eval_ms = (
            1000 * self._time_since_last_print / self._steps_since_last_print
        )

        # Pull out progress summary
        tau = self.ptsampler.tau
        nsamples = self.ptsampler.nsamples
        minimum_index = self.ptsampler.primary_sampler.chain.minimum_index
        method = self.ptsampler.primary_sampler.chain.minimum_index_method
        mindex_str = f"{minimum_index:0.2e}({method})"
        alpha = self.ptsampler.primary_sampler.acceptance_ratio
        maxl = self.ptsampler.primary_sampler.chain.max_log_likelihood

        nlikelihood = position * self.L1steps * self.ptsampler._nsamplers
        eff = 100 * nsamples / nlikelihood

        # Estimated time to finish (ETF)
        if tau < np.inf:
            remaining_samples = self.target_nsamples - nsamples
            remaining_evals = (
                remaining_samples
                * self.convergence_inputs.thin_by_nact
                * tau
                * self.L1steps
            )
            remaining_time_s = time_per_eval_ms * 1e-3 * remaining_evals
            remaining_time_dt = datetime.timedelta(seconds=remaining_time_s)
            if remaining_samples > 0:
                remaining_time = str(remaining_time_dt).split(".")[0]
            else:
                remaining_time = "0"
        else:
            remaining_time = "-"

        msg = (
            f"{position:0.2e}|{time}|{mindex_str}|t={tau:0.0f}|"
            f"n={nsamples:0.0f}|a={alpha:0.2f}|e={eff:0.1e}%|"
            f"{time_per_eval_ms:0.2f}ms/ev|maxl={maxl:0.2f}|"
            f"ETF={remaining_time}"
        )

        if self.pt_rejection_sample:
            count = self.ptsampler.rejection_sampling_count
            rse = 100 * count / nsamples
            msg += f"|rse={rse:0.2f}%"

        if self.verbose:
            print(msg, flush=True)

    def print_per_proposal(self):
        logger.info("Zero-temperature proposals:")
        for prop in self.ptsampler[0].proposal_cycle.proposal_list:
            logger.info(prop)

    def print_pt_acceptance(self):
        logger.info(f"Temperature swaps = {self.ptsampler.swap_counter['temperature']}")
        for column in self.ptsampler.sampler_list_of_tempered_lists:
            for ii, sampler in enumerate(column):
                total = sampler.pt_accepted + sampler.pt_rejected
                beta = sampler.beta
                if total > 0:
                    ratio = f"{sampler.pt_accepted / total:0.2f}"
                else:
                    ratio = "-"
                logger.info(
                    f"Temp:{ii}<->{ii+1}|"
                    f"beta={beta:0.4g}|"
                    f"hot-samp={sampler.nsamples}|"
                    f"swap={ratio}|"
                    f"conv={sampler.chain.converged}|"
                )

    def print_tau_dict(self):
        msg = f"Current taus={self.ptsampler.primary_sampler.chain.tau_dict}"
        logger.info(msg)

    @staticmethod
    def plot_progress(ptsampler, label, outdir, priors, diagnostic=False):
        logger.info("Creating diagnostic plots")
        for ii, row in ptsampler.sampler_dictionary.items():
            for jj, sampler in enumerate(row):
                plot_label = f"{label}_E{sampler.Eindex}_T{sampler.Tindex}"
                if diagnostic is True or sampler.beta == 1:
                    sampler.chain.plot(
                        outdir=outdir,
                        label=plot_label,
                        priors=priors,
                        all_samples=ptsampler.samples,
                    )

    @classmethod
    def get_expected_outputs(cls, outdir=None, label=None):
        """Get lists of the expected outputs directories and files.

        These are used by :code:`bilby_pipe` when transferring files via HTCondor.

        Parameters
        ----------
        outdir : str
            The output directory.
        label : str
            The label for the run.

        Returns
        -------
        list
            List of file names.
        list
            List of directory names. Will always be empty for bilby_mcmc.
        """
        filenames = [os.path.join(outdir, f"{label}_resume.pickle")]
        return filenames, []


class BilbyPTMCMCSampler(object):
    def __init__(
        self,
        convergence_inputs,
        pt_inputs,
        proposal_cycle,
        pt_rejection_sample,
        pool,
        use_ratio,
        evidence_method,
        initial_sample_method,
        initial_sample_dict,
        normalize_prior=True,
    ):
        self.set_pt_inputs(pt_inputs)
        self.use_ratio = use_ratio
        self.initial_sample_method = initial_sample_method
        self.initial_sample_dict = initial_sample_dict
        self.normalize_prior = normalize_prior
        self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
        self.set_convergence_inputs(convergence_inputs)
        self.pt_rejection_sample = pt_rejection_sample
        self.pool = pool
        self.evidence_method = evidence_method

        # Initialize counters
        self.swap_counter = Counter()
        self.swap_counter["temperature"] = 0
        self.swap_counter["L2-temperature"] = 0
        self.swap_counter["ensemble"] = 0
        self.swap_counter["L2-ensemble"] = int(self.L2steps / 2) + 1
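        # The offset above staggers the first ensemble swap relative to the
        # first temperature swap, so the two swap types do not always fire
        # on the same iteration.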

        self._nsamples_dict = {}
        self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(
            _sampling_convenience_dump.priors
        )
        self.sampling_time = 0
        self.ln_z_dict = dict()
        self.ln_z_err_dict = dict()

    def get_initial_betas(self):
        pt_inputs = self.pt_inputs
        if self.ntemps == 1:
            betas = np.array([1])
        elif pt_inputs.initial_betas is not None:
            betas = np.array(pt_inputs.initial_betas)
        elif pt_inputs.Tmax is not None:
            betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps)
        elif pt_inputs.Tmax_from_SNR is not None:
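            # Heuristic: a signal with matched-filter SNR rho has a peak
            # log-likelihood of roughly rho**2 / 2, so this choice of Tmax
            # gives the hottest chain a tempered peak log-likelihood of
            # about ndim / 2, comparable to the sampling dimensionality.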

            ndim = len(_sampling_convenience_dump.priors.non_fixed_keys)
            target_hot_likelihood = ndim / 2
            Tmax = pt_inputs.Tmax_from_SNR**2 / (2 * target_hot_likelihood)
            betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps)
        else:
            raise SamplerError("Unable to set temperature ladder from inputs")

        if len(betas) != self.ntemps:
            raise SamplerError("Temperatures do not match ntemps")

        return betas

    def setup_sampler_dictionary(self, convergence_inputs, proposal_cycle):
        betas = self.get_initial_betas()
        logger.info(
            f"Initializing BilbyPTMCMCSampler with: "
            f"ntemps={self.ntemps}, "
            f"nensemble={self.nensemble}, "
            f"pt_ensemble={self.pt_ensemble}, "
            f"initial_betas={betas}, "
            f"initial_sample_method={self.initial_sample_method}, "
            f"initial_sample_dict={self.initial_sample_dict}\n"
        )
        self.sampler_dictionary = dict()
        for Tindex, beta in enumerate(betas):
            if beta == 1 or self.pt_ensemble:
                n = self.nensemble
            else:
                n = 1
            temp_sampler_list = [
                BilbyMCMCSampler(
                    beta=beta,
                    Tindex=Tindex,
                    Eindex=Eindex,
                    convergence_inputs=convergence_inputs,
                    proposal_cycle=proposal_cycle,
                    use_ratio=self.use_ratio,
                    initial_sample_method=self.initial_sample_method,
                    initial_sample_dict=self.initial_sample_dict,
                    normalize_prior=self.normalize_prior,
                )
                for Eindex in range(n)
            ]
            self.sampler_dictionary[Tindex] = temp_sampler_list

        # Store data
        self._nsamplers = len(self.sampler_list)

    @property
    def sampler_list(self):
        """A list of all individual samplers"""
        return [s for item in self.sampler_dictionary.values() for s in item]

    @sampler_list.setter
    def sampler_list(self, sampler_list):
        for sampler in sampler_list:
            self.sampler_dictionary[sampler.Tindex][sampler.Eindex] = sampler

    def sampler_list_by_column(self, column):
        return [row[column] for row in self.sampler_dictionary.values()]

    @property
    def sampler_list_of_tempered_lists(self):
        if self.pt_ensemble:
            return [self.sampler_list_by_column(ii) for ii in range(self.nensemble)]
        else:
            return [self.sampler_list_by_column(0)]

    @property
    def tempered_sampler_list(self):
        return [s for s in self.sampler_list if s.beta < 1]

    @property
    def zerotemp_sampler_list(self):
        return [s for s in self.sampler_list if s.beta == 1]

    @property
    def primary_sampler(self):
        return self.sampler_dictionary[0][0]

    def set_pt_inputs(self, pt_inputs):
        logger.info(f"Setting parallel tempering inputs={pt_inputs}")
        self.pt_inputs = pt_inputs

        # Pull out only what is needed
        self.ntemps = pt_inputs.ntemps
        self.nensemble = pt_inputs.nensemble
        self.pt_ensemble = pt_inputs.pt_ensemble
        self.adapt = pt_inputs.adapt
        self.adapt_t0 = pt_inputs.adapt_t0
        self.adapt_nu = pt_inputs.adapt_nu

    def set_convergence_inputs(self, convergence_inputs):
        logger.info(f"Setting convergence_inputs={convergence_inputs}")
        self.convergence_inputs = convergence_inputs
        self.L1steps = convergence_inputs.L1steps
        self.L2steps = convergence_inputs.L2steps
        for sampler in self.sampler_list:
            sampler.set_convergence_inputs(convergence_inputs)

    @property
    def tau(self):
        return self.primary_sampler.chain.tau

    @property
    def minimum_index(self):
        return self.primary_sampler.chain.minimum_index

    @property
    def nsamples(self):
        pos = self.primary_sampler.chain.position
        if hasattr(self, "_nsamples_dict") is False:
            self._nsamples_dict = {}
        if pos in self._nsamples_dict:
            return self._nsamples_dict[pos]
        logger.debug(f"Calculating nsamples at {pos}")
        self._nsamples_dict[pos] = self._calculate_nsamples()
        return self._nsamples_dict[pos]

    @property
    def nsamples_last(self):
        if len(self._nsamples_dict) > 0:
            return list(self._nsamples_dict.values())[-1]
        else:
            return 0

    @property
    def nsamples_nocache(self):
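        # Accessing chain.tau_nocache below forces a fresh autocorrelation-
        # time estimate on every chain before nsamples is recomputed,
        # bypassing the cached tau used elsewhere.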

        for sampler in self.sampler_list:
            sampler.chain.tau_nocache
        pos = self.primary_sampler.chain.position
        self._nsamples_dict[pos] = self._calculate_nsamples()
        return self._nsamples_dict[pos]

    def _calculate_nsamples(self):
        nsamples_list = []
        for sampler in self.zerotemp_sampler_list:
            nsamples_list.append(sampler.nsamples)
        if self.pt_rejection_sample:
            for samp in self.sampler_list[1:]:
                nsamples_list.append(
                    len(samp.rejection_sample_zero_temperature_samples())
                )
        return sum(nsamples_list)

    @property
    def samples(self):
        cached_samples = getattr(self, "_cached_samples", (False,))
        if cached_samples[0] == self.position:
            return cached_samples[1]

        sample_list = []
        for sampler in self.zerotemp_sampler_list:
            sample_list.append(sampler.samples)
        if self.pt_rejection_sample:
            for sampler in self.tempered_sampler_list:
                sample_list.append(sampler.samples)
        samples = pd.concat(sample_list, ignore_index=True)
        self._cached_samples = (self.position, samples)
        return samples

    @property
    def position(self):
        return self.primary_sampler.chain.position

    @property
    def evaluations(self):
        return int(self.position * len(self.sampler_list))

    def __getitem__(self, index):
        return self.sampler_list[index]

    def step_all_chains(self):
        if self.pool:
            self.sampler_list = self.pool.map(call_step, self.sampler_list)
        else:
            for ii, sampler in enumerate(self.sampler_list):
                self.sampler_list[ii] = sampler.step()

        if self.nensemble > 1 and self.swap_counter["L2-ensemble"] >= self.L2steps:
            self.swap_counter["ensemble"] += 1
            self.swap_counter["L2-ensemble"] = 0
            self.ensemble_step()

        if self.ntemps > 1 and self.swap_counter["L2-temperature"] >= self.L2steps:
            self.swap_counter["temperature"] += 1
            self.swap_counter["L2-temperature"] = 0
            self.swap_tempered_chains()
            if self.position < self.adapt_t0 * 10:
                if self.adapt:
                    self.adapt_temperatures()
            elif self.adapt:
                logger.info(
                    f"Adaptation of temperature chains finished at step {self.position}"
                )
                self.adapt = False

        self.swap_counter["L2-ensemble"] += 1
        self.swap_counter["L2-temperature"] += 1

    @staticmethod
    def _get_sample_to_swap(sampler):
        if not (sampler.chain.converged and sampler.stop_after_convergence):
            v = sampler.chain[-1]
        else:
            v = sampler.chain.random_sample
        logl = v[LOGLKEY]
        return v, logl

    def swap_tempered_chains(self):
        if self.pt_ensemble:
            Eindexs = range(self.nensemble)
        else:
            Eindexs = [0]
        for Eindex in Eindexs:
            for Tindex in range(self.ntemps - 1):
                sampleri = self.sampler_dictionary[Tindex][Eindex]
                vi, logli = self._get_sample_to_swap(sampleri)
                betai = sampleri.beta

                samplerj = self.sampler_dictionary[Tindex + 1][Eindex]
                vj, loglj = self._get_sample_to_swap(samplerj)
                betaj = samplerj.beta

                dbeta = betaj - betai
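                # Standard parallel-tempering swap: accept with probability
                # min(1, exp[(beta_j - beta_i) * (logl_i - logl_j)])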

                with np.errstate(over="ignore"):
                    alpha_swap = np.exp(dbeta * (logli - loglj))

                if random.rng.uniform(0, 1) <= alpha_swap:
                    sampleri.chain[-1] = vj
                    samplerj.chain[-1] = vi
                    self.sampler_dictionary[Tindex][Eindex] = sampleri
                    self.sampler_dictionary[Tindex + 1][Eindex] = samplerj
                    sampleri.pt_accepted += 1
                else:
                    sampleri.pt_rejected += 1

    def ensemble_step(self):
        for Tindex, sampler_list in self.sampler_dictionary.items():
            if len(sampler_list) > 1:
                for Eindex, sampler in enumerate(sampler_list):
                    curr = sampler.chain.current_sample
                    proposal = self.ensemble_proposal_cycle.get_proposal()
                    complement = [s.chain for s in sampler_list if s != sampler]
                    prop, log_factor = proposal(sampler.chain, complement)
                    logp = sampler.log_prior(prop)

                    if logp == -np.inf:
                        sampler.reject_proposal(curr, proposal)
                        self.sampler_dictionary[Tindex][Eindex] = sampler
                        continue

                    prop[LOGPKEY] = logp
                    prop[LOGLKEY] = sampler.log_likelihood(prop)
                    alpha = np.exp(
                        log_factor
                        + sampler.beta * prop[LOGLKEY]
                        + prop[LOGPKEY]
                        - sampler.beta * curr[LOGLKEY]
                        - curr[LOGPKEY]
                    )

                    if random.rng.uniform(0, 1) <= alpha:
                        sampler.accept_proposal(prop, proposal)
                    else:
                        sampler.reject_proposal(curr, proposal)
                    self.sampler_dictionary[Tindex][Eindex] = sampler

    def adapt_temperatures(self):
        """Adapt the temperature of the chains

        Using the dynamic temperature selection described in arXiv:1501.05823,
        adapt the chains to target a constant swap ratio. This method is based
        on github.com/willvousden/ptemcee/tree/master/ptemcee
        """

        self.primary_sampler.chain.minimum_index_adapt = self.position
        tt = self.swap_counter["temperature"]
        for sampler_list in self.sampler_list_of_tempered_lists:
            betas = np.array([s.beta for s in sampler_list])
            ratios = np.array([s.acceptance_ratio for s in sampler_list[:-1]])

            # Modulate temperature adjustments with a hyperbolic decay.
            decay = self.adapt_t0 / (tt + self.adapt_t0)
            kappa = decay / self.adapt_nu

            # Construct temperature adjustments.
            dSs = kappa * (ratios[:-1] - ratios[1:])

            # Compute new ladder (hottest and coldest chains don't move).
            deltaTs = np.diff(1 / betas[:-1])
            deltaTs *= np.exp(dSs)
            betas[1:-1] = 1 / (np.cumsum(deltaTs) + 1 / betas[0])
            for sampler, beta in zip(sampler_list, betas):
                sampler.beta = beta

    @property
    def ln_z(self):
        return self.ln_z_dict.get(self.evidence_method, np.nan)

    @property
    def ln_z_err(self):
        return self.ln_z_err_dict.get(self.evidence_method, np.nan)

    def compute_evidence(self, outdir, label, make_plots=True):
        if self.ntemps == 1:
            return
        kwargs = dict(outdir=outdir, label=label, make_plots=make_plots)
        methods = dict(
            thermodynamic=self.thermodynamic_integration_evidence,
            stepping_stone=self.stepping_stone_evidence,
        )
        for key, method in methods.items():
            ln_z, ln_z_err = self.compute_evidence_per_ensemble(method, kwargs)
            self.ln_z_dict[key] = ln_z
            self.ln_z_err_dict[key] = ln_z_err
            logger.debug(
                f"Log-evidence of {ln_z:0.2f}+/-{ln_z_err:0.2f} calculated using {key} method"
            )

    def compute_evidence_per_ensemble(self, method, kwargs):
        from scipy.special import logsumexp

        if self.ntemps == 1:
            return np.nan, np.nan

        lnZ_list = []
        lnZerr_list = []
        for index, ptchain in enumerate(self.sampler_list_of_tempered_lists):
            lnZ, lnZerr = method(ptchain, **kwargs)
            lnZ_list.append(lnZ)
            lnZerr_list.append(lnZerr)

        N = len(lnZ_list)

        # Average lnZ
        lnZ = logsumexp(lnZ_list, b=1.0 / N)

        # Propagate uncertainty in combined evidence
        lnZerr = 0.5 * logsumexp(2 * np.array(lnZerr_list), b=1.0 / N)

        return lnZ, lnZerr

    def thermodynamic_integration_evidence(
        self, ptchain, outdir, label, make_plots=True
    ):
        """Computes the evidence using thermodynamic integration

        The evidence is computed without the burn-in samples and without
        thinning.
        """
        from scipy.stats import sem

        betas = []
        mean_lnlikes = []
        sem_lnlikes = []
        for sampler in ptchain:
            lnlikes = sampler.chain.get_1d_array(LOGLKEY)
            mindex = sampler.chain.minimum_index
            lnlikes = lnlikes[mindex:]
            mean_lnlikes.append(np.mean(lnlikes))
            sem_lnlikes.append(sem(lnlikes))
            betas.append(sampler.beta)

        # Convert to array and re-order
        betas = np.array(betas)[::-1]
        mean_lnlikes = np.array(mean_lnlikes)[::-1]
        sem_lnlikes = np.array(sem_lnlikes)[::-1]

        lnZ, lnZerr = self._compute_evidence_from_mean_lnlikes(betas, mean_lnlikes)

        if make_plots:
            plot_label = f"{label}_E{ptchain[0].Eindex}"
            self._create_lnZ_plots(
                betas=betas,
                mean_lnlikes=mean_lnlikes,
                outdir=outdir,
                label=plot_label,
                sem_lnlikes=sem_lnlikes,
            )

        return lnZ, lnZerr

    def stepping_stone_evidence(self, ptchain, outdir, label, make_plots=True):
        """
        Compute the evidence using the stepping-stone approximation.

        See https://arxiv.org/abs/1810.04488 and
        https://pubmed.ncbi.nlm.nih.gov/21187451/ for details.

        The uncertainty is estimated with a block-bootstrap over the
        per-step ratio estimates (see below).

        Returns
        -------
        ln_z: float
            Estimate of the natural log evidence
        ln_z_err: float
            Estimate of the uncertainty in the evidence
        """
        # Order in increasing beta
        ptchain.reverse()

        # Get maximum usable set of samples across the ptchain
        min_index = max([samp.chain.minimum_index for samp in ptchain])
        max_index = min([len(samp.chain.get_1d_array(LOGLKEY)) for samp in ptchain])
        tau = self.tau

        if max_index - min_index <= 1 or np.isinf(tau):
            return np.nan, np.nan

        # Read in log likelihoods
        ln_likes = np.array(
            [samp.chain.get_1d_array(LOGLKEY)[min_index:max_index] for samp in ptchain]
        )[:-1].T

        # Thin to only independent samples
        ln_likes = ln_likes[:: int(self.tau), :]
        steps = ln_likes.shape[0]

        # Calculate delta betas
        betas = np.array([samp.beta for samp in ptchain])

        ln_z, ln_ratio = self._calculate_stepping_stone(betas, ln_likes)

        # Implementation of the bootstrap method described in Maturana-Russel
        # et al. (2019) to estimate the evidence uncertainty.
        ll = 50  # Block length
        repeats = 100  # Repeats
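        # Each realisation draws one (thinned) step at random from every
        # overlapping block of length ll and recomputes ln Z on that
        # resampled set; the standard deviation over realisations is the
        # quoted uncertainty.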

        ln_z_realisations = []
        try:
            for _ in range(repeats):
                idxs = [random.rng.integers(i, i + ll) for i in range(steps - ll)]
                ln_z_realisations.append(
                    self._calculate_stepping_stone(betas, ln_likes[idxs, :])[0]
                )
            ln_z_err = np.std(ln_z_realisations)
        except ValueError:
            logger.info("Failed to estimate stepping stone uncertainty")
            ln_z_err = np.nan

        if make_plots:
            plot_label = f"{label}_E{ptchain[0].Eindex}"
            self._create_stepping_stone_plot(
                means=ln_ratio,
                outdir=outdir,
                label=plot_label,
            )

        return ln_z, ln_z_err

    @staticmethod
    def _calculate_stepping_stone(betas, ln_likes):
        from scipy.special import logsumexp

        n_samples = ln_likes.shape[0]
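        # Each ratio r_k = (1/n) sum_i exp[(beta_{k+1} - beta_k) * lnL_i]
        # is an importance-sampling estimate of Z_{k+1} / Z_k using samples
        # drawn at beta_k; summing ln r_k over the ladder gives ln Z.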

        d_betas = betas[1:] - betas[:-1]
        ln_ratio = logsumexp(d_betas * ln_likes, axis=0) - np.log(n_samples)
        return sum(ln_ratio), ln_ratio

    @staticmethod
    def _compute_evidence_from_mean_lnlikes(betas, mean_lnlikes):
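        # Thermodynamic integration: ln Z = int_0^1 <ln L>_beta d(beta),
        # evaluated with the trapezoid rule; the error is estimated by
        # comparing against the same integral on every other ladder point.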

        lnZ = np.trapz(mean_lnlikes, betas)
        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1], betas[::-1][::2][::-1])
        lnZerr = np.abs(lnZ - z2)
        return lnZ, lnZerr

    def _create_lnZ_plots(self, betas, mean_lnlikes, outdir, label, sem_lnlikes=None):
        import matplotlib.pyplot as plt

        logger.debug("Creating thermodynamic evidence diagnostic plot")

        fig, ax1 = plt.subplots()
        if betas[-1] == 0:
            x, y = betas[:-1], mean_lnlikes[:-1]
        else:
            x, y = betas, mean_lnlikes
        if sem_lnlikes is not None:
            ax1.errorbar(x, y, sem_lnlikes, fmt="-")
        else:
            ax1.plot(x, y, "-o")
        ax1.set_xscale("log")
        ax1.set_xlabel(r"$\beta$")
        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")

        plt.tight_layout()
        fig.savefig("{}/{}_beta_lnl.png".format(outdir, label))
        plt.close()

    def _create_stepping_stone_plot(self, means, outdir, label):
        import matplotlib.pyplot as plt

        logger.debug("Creating stepping stone evidence diagnostic plot")

        n_steps = len(means)

        fig, axes = plt.subplots(nrows=2, figsize=(8, 10))

        ax = axes[0]
        ax.plot(np.arange(1, n_steps + 1), means)
        ax.set_xlabel("$k$")
        ax.set_ylabel("$r_{k}$")

        ax = axes[1]
        ax.plot(np.arange(1, n_steps + 1), np.cumsum(means))
        ax.set_xlabel("$k$")
        ax.set_ylabel("Cumulative $\\ln Z$")

        plt.tight_layout()
        fig.savefig("{}/{}_stepping_stone.png".format(outdir, label))
        plt.close()

    @property
    def rejection_sampling_count(self):
        if self.pt_rejection_sample:
            counts = 0
            for column in self.sampler_list_of_tempered_lists:
                for sampler in column:
                    counts += sampler.rejection_sampling_count
            return counts
        else:
            return None


class BilbyMCMCSampler(object):
    def __init__(
        self,
        convergence_inputs,
        proposal_cycle=None,
        beta=1,
        Tindex=0,
        Eindex=0,
        use_ratio=False,
        initial_sample_method="prior",
        initial_sample_dict=None,
        normalize_prior=True,
    ):
        self.beta = beta
        self.Tindex = Tindex
        self.Eindex = Eindex
        self.use_ratio = use_ratio
        self.normalize_prior = normalize_prior
        self.parameters = _sampling_convenience_dump.priors.non_fixed_keys
        self.ndim = len(self.parameters)

        if initial_sample_method.lower() == "prior":
            full_sample_dict = _sampling_convenience_dump.priors.sample()
            initial_sample = {
                k: v
                for k, v in full_sample_dict.items()
                if k in _sampling_convenience_dump.priors.non_fixed_keys
            }
        elif initial_sample_method.lower() in ["maximize", "maximise", "maximum"]:
            initial_sample = get_initial_maximimum_posterior_sample(self.beta)
        else:
            raise ValueError(
                f"initial sample method {initial_sample_method} not understood"
            )

        if initial_sample_dict is not None:
            initial_sample.update(initial_sample_dict)

        if self.beta == 1:
            logger.info(f"Using initial sample {initial_sample}")

        initial_sample = Sample(initial_sample)
        initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
        initial_sample[LOGPKEY] = self.log_prior(initial_sample)

        self.chain = Chain(initial_sample=initial_sample)
        self.set_convergence_inputs(convergence_inputs)

        self.accepted = 0
        self.rejected = 0
        self.pt_accepted = 0
        self.pt_rejected = 0
        self.rejection_sampling_count = 0

        if isinstance(proposal_cycle, str):
            # Only print warnings for the primary sampler
            if Tindex == 0 and Eindex == 0:
                warn = True
            else:
                warn = False

            self.proposal_cycle = proposals.get_proposal_cycle(
                proposal_cycle,
                _sampling_convenience_dump.priors,
                L1steps=self.chain.L1steps,
                warn=warn,
            )
        elif isinstance(proposal_cycle, proposals.ProposalCycle):
            self.proposal_cycle = proposal_cycle
        else:
            raise SamplerError("Proposal cycle not understood")

        if self.Tindex == 0 and self.Eindex == 0:
            logger.info(f"Using {self.proposal_cycle}")

    def set_convergence_inputs(self, convergence_inputs):
        for key, val in convergence_inputs._asdict().items():
            setattr(self.chain, key, val)
        self.target_nsamples = convergence_inputs.target_nsamples
        self.stop_after_convergence = convergence_inputs.stop_after_convergence

    def log_likelihood(self, sample):
        _sampling_convenience_dump.likelihood.parameters.update(sample.sample_dict)

        if self.use_ratio:
            logl = _sampling_convenience_dump.likelihood.log_likelihood_ratio()
        else:
            logl = _sampling_convenience_dump.likelihood.log_likelihood()

        return logl

    def log_prior(self, sample):
        return _sampling_convenience_dump.priors.ln_prob(
            sample.parameter_only_dict,
            normalized=self.normalize_prior,
        )

    def accept_proposal(self, prop, proposal):
        self.chain.append(prop)
        self.accepted += 1
        proposal.accepted += 1

    def reject_proposal(self, curr, proposal):
        self.chain.append(curr)
        self.rejected += 1
        proposal.rejected += 1

    def step(self):
        if self.stop_after_convergence and self.chain.converged:
            return self

        internal_steps = 0
        internal_accepted = 0
        internal_rejected = 0
        curr = self.chain.current_sample.copy()
        while internal_steps < self.chain.L1steps:
            internal_steps += 1
            proposal = self.proposal_cycle.get_proposal()
            prop, log_factor = proposal(
                self.chain,
                likelihood=_sampling_convenience_dump.likelihood,
                priors=_sampling_convenience_dump.priors,
            )
            logp = self.log_prior(prop)

            if np.isinf(logp) or np.isnan(logp):
                internal_rejected += 1
                proposal.rejected += 1
                continue

            prop[LOGPKEY] = logp
            prop[LOGLKEY] = self.log_likelihood(prop)

            if np.isinf(prop[LOGLKEY]) or np.isnan(prop[LOGLKEY]):
                internal_rejected += 1
                proposal.rejected += 1
                continue
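
            # Metropolis-Hastings acceptance on the tempered posterior:
            # alpha = exp(log_factor) * [L(prop)**beta * pi(prop)]
            #                         / [L(curr)**beta * pi(curr)],
            # where log_factor is the proposal's log Hastings ratio.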

            with np.errstate(over="ignore"):
                alpha = np.exp(
                    log_factor
                    + self.beta * prop[LOGLKEY]
                    + prop[LOGPKEY]
                    - self.beta * curr[LOGLKEY]
                    - curr[LOGPKEY]
                )

            if random.rng.uniform(0, 1) <= alpha:
                internal_accepted += 1
                proposal.accepted += 1
                curr = prop
                self.chain.current_sample = curr
            else:
                internal_rejected += 1
                proposal.rejected += 1

        self.chain.append(curr)
        self.rejected += internal_rejected
        self.accepted += internal_accepted
        return self

    @property
    def nsamples(self):
        nsamples = self.chain.nsamples
        if nsamples > self.target_nsamples and self.chain.converged is False:
            logger.debug(f"Temperature {self.Tindex} chain reached convergence")
            self.chain.converged = True
        return nsamples

    @property
    def acceptance_ratio(self):
        return self.accepted / (self.accepted + self.rejected)

    @property
    def samples(self):
        if self.beta == 1:
            return self.chain.samples
        else:
            return self.rejection_sample_zero_temperature_samples(print_message=True)

    def rejection_sample_zero_temperature_samples(self, print_message=False):
        beta = self.beta
        chain = self.chain
        hot_samples = pd.DataFrame(
            chain._chain_array[chain.minimum_index : chain.position], columns=chain.keys
        )
        if len(hot_samples) == 0:
            logger.debug(
                f"Rejection sampling for Temp {self.Tindex} failed: "
                "no usable hot samples"
            )
            return hot_samples

        # Pull out log likelihood
        zerotemp_logl = hot_samples[LOGLKEY]

        # Revert to true likelihood if needed
        if _sampling_convenience_dump.use_ratio:
            zerotemp_logl += (
                _sampling_convenience_dump.likelihood.noise_log_likelihood()
            )

        # Calculate normalised weights
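        # The tempered chain targets L**beta * pi, so weighting each hot
        # sample by L**(1 - beta) recovers the beta = 1 posterior.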

        log_weights = (1 - beta) * zerotemp_logl
        max_weight = np.max(log_weights)
        unnormalised_weights = np.exp(log_weights - max_weight)
        weights = unnormalised_weights / np.sum(unnormalised_weights)

        # Rejection sample
        samples = rejection_sample(hot_samples, weights)

        # Logging
        self.rejection_sampling_count = len(samples)

        if print_message:
            logger.info(
                f"Rejection sampling Temp {self.Tindex}, beta={beta:0.2f} "
                f"yielded {len(samples)} samples"
            )
        return samples


def get_initial_maximimum_posterior_sample(beta):
    """A method to attempt optimization of the maximum posterior

    This uses a simple scipy optimization approach, starting from a number
    of draws from the prior to avoid problems with local optimization.

    """
    logger.info("Finding initial maximum posterior estimate")
    likelihood = _sampling_convenience_dump.likelihood
    priors = _sampling_convenience_dump.priors
    search_parameter_keys = _sampling_convenience_dump.search_parameter_keys

    bounds = []
    for key in search_parameter_keys:
        bounds.append((priors[key].minimum, priors[key].maximum))

    def neg_log_post(x):
        sample = {key: val for key, val in zip(search_parameter_keys, x)}
        ln_prior = priors.ln_prob(sample)

        if np.isinf(ln_prior):
            # Out-of-prior points get an infinite penalty for the minimizer
            return np.inf

        likelihood.parameters.update(sample)

        return -beta * likelihood.log_likelihood() - ln_prior

    res = differential_evolution(neg_log_post, bounds, popsize=100, init="sobol")
    if res.success:
        sample = {key: val for key, val in zip(search_parameter_keys, res.x)}
        logger.info(f"Initial maximum posterior estimate {sample}")
        return sample
    else:
        raise ValueError("Failed to find initial maximum posterior estimate")


# Methods used to aid parallelisation:


def call_step(sampler):
    sampler = sampler.step()
    return sampler
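

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). Assuming a
# working bilby installation, this sampler is normally driven through
# bilby.run_sampler rather than instantiated directly; the toy likelihood
# below is hypothetical and exists only for demonstration.
if __name__ == "__main__":
    import bilby

    class UnitGaussianLikelihood(bilby.Likelihood):
        """Toy likelihood: a single unit-variance Gaussian datum at zero."""

        def __init__(self):
            super().__init__(parameters={"mu": None})

        def log_likelihood(self):
            return -0.5 * self.parameters["mu"] ** 2

    priors = bilby.core.prior.PriorDict(
        dict(mu=bilby.core.prior.Uniform(-5, 5, name="mu"))
    )

    result = bilby.run_sampler(
        likelihood=UnitGaussianLikelihood(),
        priors=priors,
        sampler="bilby_mcmc",
        nsamples=500,
        ntemps=2,
        printdt=10,
        outdir="outdir",
        label="bilby_mcmc_demo",
    )
    print(result.posterior["mu"].describe())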