Coverage for bilby/core/sampler/base_sampler.py: 76%

446 statements  

coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

import datetime
import os
import shutil
import signal
import sys
import tempfile
import time

import attr
import numpy as np
from pandas import DataFrame

from ..prior import Constraint, DeltaFunction, Prior, PriorDict
from ..result import Result, read_in_result
from ..utils import (
    Counter,
    check_directory_exists_and_if_not_mkdir,
    command_line_args,
    logger,
)
from ..utils.random import seed as set_seed


@attr.s
class _SamplingContainer:
    """
    A container class for objects that are stored independently in each thread
    for some samplers.

    A single instance of this class lives in this module and can be accessed
    by the individual samplers.

    This includes the:

    - likelihood (bilby.core.likelihood.Likelihood)
    - priors (bilby.core.prior.PriorDict)
    - search_parameter_keys (list)
    - use_ratio (bool)
    """

    likelihood = attr.ib(default=None)
    priors = attr.ib(default=None)
    search_parameter_keys = attr.ib(default=None)
    use_ratio = attr.ib(default=False)


_sampling_convenience_dump = _SamplingContainer()


def _initialize_global_variables(
    likelihood,
    priors,
    search_parameter_keys,
    use_ratio,
):
    """
    Store a global copy of the likelihood, priors, and search keys for
    multiprocessing.
    """
    global _sampling_convenience_dump
    _sampling_convenience_dump.likelihood = likelihood
    _sampling_convenience_dump.priors = priors
    _sampling_convenience_dump.search_parameter_keys = search_parameter_keys
    _sampling_convenience_dump.use_ratio = use_ratio
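

# A hedged sketch (not part of bilby): how a worker process might read the
# per-process state stored above once the pool initializer has run. The
# function name and its ``theta`` argument are hypothetical and exist only
# for illustration; see ``Sampler._setup_pool`` below for where the
# initializer is actually wired up.
def _example_worker_log_likelihood(theta):
    dump = _sampling_convenience_dump
    # Map the sampled array onto the named search parameters.
    dump.likelihood.parameters.update(
        dict(zip(dump.search_parameter_keys, theta))
    )
    if dump.use_ratio:
        return dump.likelihood.log_likelihood_ratio()
    return dump.likelihood.log_likelihood()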


def signal_wrapper(method):
    """
    Decorator to wrap a method of a class to set system signals before running
    and reset them after.

    Parameters
    ==========
    method: callable
        The method to call; this assumes the first argument is `self`
        and that `self` has a `write_current_state_and_exit` method.

    Returns
    =======
    output: callable
        The wrapped method.
    """

    def wrapped(self, *args, **kwargs):
        try:
            old_term = signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
            old_int = signal.signal(signal.SIGINT, self.write_current_state_and_exit)
            old_alarm = signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
            _set = True
        except (AttributeError, ValueError):
            _set = False
            logger.debug(
                "Setting signal attributes is unavailable on this system. "
                "This is likely the case if you are running on a Windows machine "
                "and can be safely ignored."
            )
        output = method(self, *args, **kwargs)
        if _set:
            signal.signal(signal.SIGTERM, old_term)
            signal.signal(signal.SIGINT, old_int)
            signal.signal(signal.SIGALRM, old_alarm)
        return output

    return wrapped
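

# A hedged, hypothetical usage sketch (not part of bilby): ``signal_wrapper``
# only requires that the decorated method's class provides a
# ``write_current_state_and_exit`` method. ``_ExampleCheckpointer`` exists
# purely for illustration.
class _ExampleCheckpointer:
    def write_current_state_and_exit(self, signum=None, frame=None):
        # A real sampler would checkpoint to disk here before exiting.
        print(f"Caught signal {signum}, checkpointing before exit")

    @signal_wrapper
    def run(self):
        pass  # the long-running sampling loop would go here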


class Sampler(object):
    """A sampler object to aid in setting up an inference run

    Parameters
    ==========
    likelihood: likelihood.Likelihood
        An object with a log_l method
    priors: bilby.core.prior.PriorDict, dict
        Priors to be used in the search.
        This has attributes for each parameter to be sampled.
    external_sampler: str, Sampler, optional
        A string containing the module name of the sampler or an instance of
        this class
    outdir: str, optional
        Name of the output directory
    label: str, optional
        Naming scheme of the output files
    use_ratio: bool, optional
        Switch to set whether or not you want to use the log-likelihood ratio
        or just the log-likelihood
    plot: bool, optional
        Switch to set whether or not you want to create traceplots
    injection_parameters: dict
        A dictionary of the injection parameters
    meta_data: dict
        A dictionary of extra meta data to store in the result
    result_class: bilby.core.result.Result, or child of
        The result class to use. By default, `bilby.core.result.Result` is used,
        but objects that inherit from this class and provide additional methods
        can be given.
    soft_init: bool, optional
        Switch to enable a soft initialization that prevents the likelihood
        from being tested before running the sampler. This is relevant when
        using custom likelihoods that must NOT be initialized on the main thread
        when using multiprocessing, e.g. when using tensorflow in the likelihood.
    **kwargs: dict
        Additional keyword arguments

    Attributes
    ==========
    likelihood: likelihood.Likelihood
        An object with a log_l method
    priors: bilby.core.prior.PriorDict
        Priors to be used in the search.
        This has attributes for each parameter to be sampled.
    external_sampler: Module
        An external module containing an implementation of a sampler.
    outdir: str
        Name of the output directory
    label: str
        Naming scheme of the output files
    use_ratio: bool
        Switch to set whether or not you want to use the log-likelihood ratio
        or just the log-likelihood
    plot: bool
        Switch to set whether or not you want to create traceplots
    skip_import_verification: bool
        Skips the check of whether the sampler is installed if true. This is
        only advisable for testing environments
    result: bilby.core.result.Result
        Container for the results of the sampling run
    exit_code: int
        System exit code to return on interrupt
    kwargs: dict
        Dictionary of keyword arguments that can be used in the external sampler
    hard_exit: bool
        Whether the implemented sampler exits hard (:code:`os._exit` rather
        than :code:`sys.exit`). The latter can be escaped as :code:`SystemExit`.
        The former cannot.
    sampler_name : str
        Name of the sampler. This is used when creating the output directory for
        the sampler.
    abbreviation : str
        Abbreviated name of the sampler. Does not have to be specified in child
        classes. If set to a value other than :code:`None`, this will be used
        instead of :code:`sampler_name` when creating the output directory.

    Raises
    ======
    TypeError:
        If external_sampler is neither a string nor an instance of this class.
        If not all likelihood.parameters have been defined.
    ImportError:
        If the external_sampler string does not refer to a sampler that is
        installed on this system.
    AttributeError:
        If some of the priors can't be sampled.

    """

    sampler_name = "sampler"
    abbreviation = None
    default_kwargs = dict()
    npool_equiv_kwargs = [
        "npool",
        "queue_size",
        "threads",
        "nthreads",
        "cores",
        "n_pool",
    ]
    sampling_seed_equiv_kwargs = ["sampling_seed", "seed", "random_seed"]
    hard_exit = False
    sampling_seed_key = None
    """Name of the keyword argument for setting the sampling seed for the
    specific sampler. If a specific sampler does not have a sampling seed
    option, then it should be left as None.
    """
    check_point_equiv_kwargs = ["check_point_deltaT", "check_point_delta_t"]

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        injection_parameters=None,
        meta_data=None,
        result_class=None,
        likelihood_benchmark=False,
        soft_init=False,
        exit_code=130,
        npool=1,
        **kwargs,
    ):
        self.likelihood = likelihood
        if isinstance(priors, PriorDict):
            self.priors = priors
        else:
            self.priors = PriorDict(priors)
        self.label = label
        self.outdir = outdir
        self.injection_parameters = injection_parameters
        self.meta_data = meta_data
        self.use_ratio = use_ratio
        self._npool = npool
        if not skip_import_verification:
            self._verify_external_sampler()
        self.external_sampler_function = None
        self.plot = plot
        self.likelihood_benchmark = likelihood_benchmark

        self._search_parameter_keys = list()
        self._fixed_parameter_keys = list()
        self._constraint_parameter_keys = list()
        self._initialise_parameters()
        self._log_information_about_priors_and_likelihood()

        self.exit_code = exit_code

        self._log_likelihood_eval_time = np.nan
        if not soft_init:
            self._verify_parameters()
            self._log_likelihood_eval_time = self._time_likelihood()
            self._verify_use_ratio()

        self.kwargs = kwargs

        self._check_cached_result(result_class)

        self._log_summary_for_sampler()

        self.result = self._initialise_result(result_class)
        self.likelihood_count = None
        if self.likelihood_benchmark:
            self.likelihood_count = Counter()

    @property
    def search_parameter_keys(self):
        """list: List of parameter keys that are being sampled"""
        return self._search_parameter_keys

    @property
    def fixed_parameter_keys(self):
        """list: List of parameter keys that are not being sampled"""
        return self._fixed_parameter_keys

    @property
    def constraint_parameter_keys(self):
        """list: List of parameters providing prior constraints"""
        return self._constraint_parameter_keys

    @property
    def ndim(self):
        """int: Number of dimensions of the search parameter space"""
        return len(self._search_parameter_keys)

    @property
    def kwargs(self):
        """dict: Container for the kwargs. Has more sophisticated logic in subclasses"""
        return self._kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        self._kwargs = self.default_kwargs.copy()
        self._translate_kwargs(kwargs)
        self._kwargs.update(kwargs)
        self._verify_kwargs_against_default_kwargs()

    def _translate_kwargs(self, kwargs):
        """Translate keyword arguments.

        By default this only translates the sampling seed if the sampler has
        :code:`sampling_seed_key` set.
        """
        if self.sampling_seed_key and self.sampling_seed_key not in kwargs:
            for equiv in self.sampling_seed_equiv_kwargs:
                if equiv in kwargs:
                    kwargs[self.sampling_seed_key] = kwargs.pop(equiv)
                    set_seed(kwargs[self.sampling_seed_key])
        return kwargs
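
    # A hedged, hypothetical illustration (not part of bilby): for a sampler
    # subclass with ``sampling_seed_key = "seed"``, a user-supplied equivalent
    # such as ``random_seed=42`` is renamed in place before the kwargs are
    # merged with ``default_kwargs``, i.e.
    #
    #     kwargs = {"random_seed": 42}
    #     sampler._translate_kwargs(kwargs)
    #     assert kwargs == {"seed": 42}
    #
    # and bilby's global random state is seeded via ``set_seed(42)``.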

    @property
    def external_sampler_name(self):
        return self.__class__.__name__.lower()

    def _verify_external_sampler(self):
        external_sampler_name = self.external_sampler_name
        try:
            __import__(external_sampler_name)
        except (ImportError, SystemExit):
            raise SamplerNotInstalledError(
                f"Sampler {external_sampler_name} is not installed on this system"
            )

    def _verify_kwargs_against_default_kwargs(self):
        """
        Check if the kwargs are contained in the list of available arguments
        of the external sampler.
        """
        args = self.default_kwargs
        bad_keys = []
        for user_input in self.kwargs.keys():
            if user_input not in args:
                logger.warning(
                    f"Supplied argument '{user_input}' not an argument of '{self.__class__.__name__}', removing."
                )
                bad_keys.append(user_input)
        for key in bad_keys:
            self.kwargs.pop(key)

    def _initialise_parameters(self):
        """
        Go through the list of priors and add keys to the fixed and search
        parameter key lists depending on whether the respective parameter
        is fixed.
        """
        for key in self.priors:
            if (
                isinstance(self.priors[key], Prior)
                and self.priors[key].is_fixed is False
            ):
                self._search_parameter_keys.append(key)
            elif isinstance(self.priors[key], Constraint):
                self._constraint_parameter_keys.append(key)
            elif isinstance(self.priors[key], DeltaFunction):
                self.likelihood.parameters[key] = self.priors[key].sample()
                self._fixed_parameter_keys.append(key)
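
    # A hedged, hypothetical example (not part of bilby) of how a prior dict
    # is split by ``_initialise_parameters``: given
    #
    #     priors = dict(
    #         a=bilby.core.prior.Uniform(0, 1, "a"),        # sampled
    #         b=bilby.core.prior.DeltaFunction(0.5, "b"),   # fixed
    #         c=bilby.core.prior.Constraint(0, 1, "c"),     # constraint
    #     )
    #
    # the sampler ends up with search_parameter_keys == ["a"],
    # fixed_parameter_keys == ["b"], and constraint_parameter_keys == ["c"],
    # with ``likelihood.parameters["b"]`` set to 0.5.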

    def _log_information_about_priors_and_likelihood(self):
        logger.info("Analysis priors:")
        for key in self._search_parameter_keys + self._constraint_parameter_keys:
            logger.info(f"{key}={self.priors[key]}")
        for key in self._fixed_parameter_keys:
            logger.info(f"{key}={self.priors[key].peak}")
        logger.info(f"Analysis likelihood class: {self.likelihood.__class__}")
        logger.info(
            f"Analysis likelihood noise evidence: {self.likelihood.noise_log_likelihood()}"
        )

    def _initialise_result(self, result_class):
        """
        Returns
        =======
        bilby.core.result.Result: An initial template for the result

        """
        result_kwargs = dict(
            label=self.label,
            outdir=self.outdir,
            sampler=self.__class__.__name__.lower(),
            search_parameter_keys=self._search_parameter_keys,
            fixed_parameter_keys=self._fixed_parameter_keys,
            constraint_parameter_keys=self._constraint_parameter_keys,
            priors=self.priors,
            meta_data=self.meta_data,
            injection_parameters=self.injection_parameters,
            sampler_kwargs=self.kwargs,
            use_ratio=self.use_ratio,
        )

        if result_class is None:
            result = Result(**result_kwargs)
        elif issubclass(result_class, Result):
            result = result_class(**result_kwargs)
        else:
            raise ValueError(f"Input result_class={result_class} not understood")

        return result

    def _verify_parameters(self):
        """Evaluate a set of parameters drawn from the prior

        Tests if the likelihood evaluation passes

        Raises
        ======
        TypeError
            Likelihood can't be evaluated.

        """

        if self.priors.test_has_redundant_keys():
            raise IllegalSamplingSetError(
                "Your sampling set contains redundant parameters."
            )

        theta = self.priors.sample_subset_constrained_as_array(
            self.search_parameter_keys, size=1
        )[:, 0]
        try:
            self.log_likelihood(theta)
        except TypeError as e:
            raise TypeError(
                f"Likelihood evaluation failed with message: \n'{e}'\n"
                f"Have you specified all the parameters:\n{self.likelihood.parameters}"
            )

    def _time_likelihood(self, n_evaluations=100):
        """Times the likelihood evaluation and prints an info message

        Parameters
        ==========
        n_evaluations: int
            The number of evaluations used to estimate the evaluation time

        Returns
        =======
        log_likelihood_eval_time: float
            The time (in s) it took for one likelihood evaluation
        """

        t1 = datetime.datetime.now()
        for _ in range(n_evaluations):
            theta = self.priors.sample_subset_constrained_as_array(
                self._search_parameter_keys, size=1
            )[:, 0]
            self.log_likelihood(theta)
        total_time = (datetime.datetime.now() - t1).total_seconds()
        log_likelihood_eval_time = total_time / n_evaluations

        if log_likelihood_eval_time == 0:
            log_likelihood_eval_time = np.nan
            logger.info("Unable to measure single likelihood time")
        else:
            logger.info(
                f"Single likelihood evaluation took {log_likelihood_eval_time:.3e} s"
            )
        return log_likelihood_eval_time

    def _verify_use_ratio(self):
        """
        Checks if use_ratio is set. Prints a warning if use_ratio is set but
        not properly implemented.
        """
        try:
            self.priors.sample_subset(self.search_parameter_keys)
        except (KeyError, AttributeError):
            logger.error(
                f"Cannot sample from priors with keys: {self.search_parameter_keys}."
            )
            raise
        if self.use_ratio is False:
            logger.debug("use_ratio set to False")
            return

        ratio_is_nan = np.isnan(self.likelihood.log_likelihood_ratio())

        if self.use_ratio is True and ratio_is_nan:
            logger.warning(
                "You have requested to use the log-likelihood ratio, but it "
                "returns a NaN"
            )
        elif self.use_ratio is None and not ratio_is_nan:
            logger.debug("use_ratio not specified but gives valid answer, setting True")
            self.use_ratio = True

    def prior_transform(self, theta):
        """Prior transform method that is passed into the external sampler.

        Parameters
        ==========
        theta: list
            List of sampled values on a unit interval

        Returns
        =======
        list: Properly rescaled sampled values
        """
        return self.priors.rescale(self._search_parameter_keys, theta)

    def log_prior(self, theta):
        """

        Parameters
        ==========
        theta: list
            List of values for the search parameters

        Returns
        =======
        float: Joint ln prior probability of theta

        """
        params = {key: t for key, t in zip(self._search_parameter_keys, theta)}
        return self.priors.ln_prob(params)

    def log_likelihood(self, theta):
        """

        Parameters
        ==========
        theta: list
            List of values for the likelihood parameters

        Returns
        =======
        float: Log-likelihood or log-likelihood-ratio given the current
            likelihood.parameter values

        """
        if self.likelihood_benchmark:
            try:
                self.likelihood_count.increment()
            except AttributeError:
                pass
        params = {key: t for key, t in zip(self._search_parameter_keys, theta)}
        self.likelihood.parameters.update(params)
        if self.use_ratio:
            return self.likelihood.log_likelihood_ratio()
        else:
            return self.likelihood.log_likelihood()
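
    # A hedged, hypothetical illustration (not part of bilby) of the
    # unit-cube convention shared by ``prior_transform`` and the nested
    # samplers that consume it: for two Uniform(0, 10) priors,
    #
    #     theta = sampler.prior_transform([0.5, 0.1])  # -> [5.0, 1.0]
    #     sampler.log_prior(theta)       # joint ln prior at theta
    #     sampler.log_likelihood(theta)  # ln L (or ln L ratio if use_ratio)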

    def get_random_draw_from_prior(self):
        """Get a random draw from the prior distribution

        Returns
        =======
        draw: array_like
            An ndim-length array of values drawn from the prior. Parameters
            with delta-function (or fixed) priors are not returned

        """
        new_sample = self.priors.sample()
        draw = np.array(list(new_sample[key] for key in self._search_parameter_keys))
        self.check_draw(draw)
        return draw

    def get_initial_points_from_prior(self, npoints=1):
        """Method to draw a set of live points from the prior

        This iterates over draws from the prior until all the samples have a
        finite prior and likelihood (relevant for constrained priors).

        Parameters
        ==========
        npoints: int
            The number of values to return

        Returns
        =======
        unit_cube, parameters, likelihood: tuple of array_like
            unit_cube (nlive, ndim) is an array of the prior samples from the
            unit cube, parameters (nlive, ndim) is the unit_cube array
            transformed to the target space, while likelihood (nlive) are the
            likelihood evaluations.

        """
        from ..utils.random import rng

        logger.info("Generating initial points from the prior")
        unit_cube = []
        parameters = []
        likelihood = []
        while len(unit_cube) < npoints:
            unit = rng.uniform(0, 1, self.ndim)
            theta = self.prior_transform(unit)
            if self.check_draw(theta, warning=False):
                unit_cube.append(unit)
                parameters.append(theta)
                likelihood.append(self.log_likelihood(theta))

        return np.array(unit_cube), np.array(parameters), np.array(likelihood)

    def check_draw(self, theta, warning=True):
        """
        Checks if the draw will generate an infinite prior or likelihood

        Also catches the output of `numpy.nan_to_num`.

        Parameters
        ==========
        theta: array_like
            Parameter values at which to evaluate likelihood
        warning: bool
            Whether or not to print a warning

        Returns
        =======
        bool
            True if the likelihood and prior are finite, False otherwise

        """
        log_p = self.log_prior(theta)
        log_l = self.log_likelihood(theta)
        return self._check_bad_value(
            val=log_p, warning=warning, theta=theta, label="prior"
        ) and self._check_bad_value(
            val=log_l, warning=warning, theta=theta, label="likelihood"
        )

    @staticmethod
    def _check_bad_value(val, warning, theta, label):
        val = np.abs(val)
        bad_values = [np.inf, np.nan_to_num(np.inf)]
        if val in bad_values or np.isnan(val):
            if warning:
                logger.warning(f"Prior draw {theta} has inf {label}")
            return False
        return True
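
    # A hedged note (not part of bilby): ``np.nan_to_num(np.inf)`` evaluates
    # to ``np.finfo(float).max`` (about 1.8e308), so ``bad_values`` also
    # rejects draws whose log-prior or log-likelihood was clipped from
    # +/-inf by ``numpy.nan_to_num`` upstream, e.g.
    #
    #     Sampler._check_bad_value(
    #         val=np.nan_to_num(-np.inf), warning=False, theta=None, label="prior"
    #     )  # -> False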

    def run_sampler(self):
        """A template method to run in subclasses"""
        pass

    def _run_test(self):
        """
        TODO: Implement this method
        Raises
        ======
        ValueError: in any case
        """
        raise ValueError("Method not yet implemented")

    def _check_cached_result(self, result_class=None):
        """Check if the cached data file exists and can be used"""

        if command_line_args.clean:
            logger.debug("Command line argument clean given, forcing rerun")
            self.cached_result = None
            return

        try:
            self.cached_result = read_in_result(
                outdir=self.outdir, label=self.label, result_class=result_class
            )
        except IOError:
            self.cached_result = None

        if command_line_args.use_cached:
            logger.debug("Command line argument cached given, no cache check performed")
            return

        logger.debug("Checking cached data")
        if self.cached_result:
            check_keys = ["search_parameter_keys", "fixed_parameter_keys"]
            use_cache = True
            for key in check_keys:
                if (
                    self.cached_result._check_attribute_match_to_other_object(key, self)
                    is False
                ):
                    logger.debug(f"Cached value {key} is unmatched")
                    use_cache = False
            try:
                # Recursively check the dictionaries, allowing for numpy arrays
                np.testing.assert_equal(
                    self.meta_data["likelihood"],
                    self.cached_result.meta_data["likelihood"],
                )
            except AssertionError:
                use_cache = False
            if use_cache is False:
                self.cached_result = None

696 """Print a summary of the sampler used and its kwargs""" 

697 if self.cached_result is None: 

698 kwargs_print = self.kwargs.copy() 

699 for k in kwargs_print: 

700 if isinstance(kwargs_print[k], (list, np.ndarray)): 

701 array_repr = np.array(kwargs_print[k]) 

702 if array_repr.size > 10: 

703 kwargs_print[k] = f"array_like, shape={array_repr.shape}" 

704 elif isinstance(kwargs_print[k], DataFrame): 

705 kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}" 

706 logger.info( 

707 f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}" 

708 ) 

709 

710 def calc_likelihood_count(self): 

711 if self.likelihood_benchmark: 

712 self.result.num_likelihood_evaluations = self.likelihood_count.value 

713 else: 

714 return None 

715 

716 @property 

717 def npool(self): 

718 for key in self.npool_equiv_kwargs: 

719 if key in self.kwargs: 

720 return self.kwargs[key] 

721 return self._npool 

722 

    def _log_interruption(self, signum=None):
        if signum == 14:
            logger.info(
                f"Run interrupted by alarm signal {signum}: checkpoint and exit on {self.exit_code}"
            )
        else:
            logger.info(
                f"Run interrupted by signal {signum}: checkpoint and exit on {self.exit_code}"
            )

    def write_current_state_and_exit(self, signum=None, frame=None):
        """
        Make sure that if a pool of jobs is running only the parent tries to
        checkpoint and exit. Only the parent has a 'pool' attribute.

        For samplers that must hard exit (typically due to a non-Python
        process) use :code:`os._exit`, which cannot be excepted. Other
        samplers' exits can be caught as a :code:`SystemExit`.
        """
        if self.npool in (1, None) or getattr(self, "pool", None) is not None:
            self._log_interruption(signum=signum)
            self.write_current_state()
            self._close_pool()
            if self.hard_exit:
                os._exit(self.exit_code)
            else:
                sys.exit(self.exit_code)

    def _close_pool(self):
        if getattr(self, "pool", None) is not None:
            logger.info("Starting to close worker pool.")
            self.pool.close()
            self.pool.join()
            self.pool = None
            self.kwargs["pool"] = self.pool
            logger.info("Finished closing worker pool.")

    def _setup_pool(self):
        if self.kwargs.get("pool", None) is not None:
            logger.info("Using user defined pool.")
            self.pool = self.kwargs["pool"]
        elif self.npool is not None and self.npool > 1:
            logger.info(f"Setting up multiprocessing pool with {self.npool} processes")
            import multiprocessing

            self.pool = multiprocessing.Pool(
                processes=self.npool,
                initializer=_initialize_global_variables,
                initargs=(
                    self.likelihood,
                    self.priors,
                    self._search_parameter_keys,
                    self.use_ratio,
                ),
            )
        else:
            self.pool = None
            _initialize_global_variables(
                likelihood=self.likelihood,
                priors=self.priors,
                search_parameter_keys=self._search_parameter_keys,
                use_ratio=self.use_ratio,
            )
        self.kwargs["pool"] = self.pool

    def write_current_state(self):
        raise NotImplementedError()

    @classmethod
    def get_expected_outputs(cls, outdir=None, label=None):
        """Get lists of the expected output directories and files.

        These are used by :code:`bilby_pipe` when transferring files via HTCondor.
        Both can be empty. Defaults to a single directory:
        :code:`"{outdir}/{name}_{label}/"`, where :code:`name`
        is :code:`abbreviation` if it is defined for the sampler class, otherwise
        it defaults to :code:`sampler_name`.

        Parameters
        ----------
        outdir : str
            The output directory.
        label : str
            The label for the run.

        Returns
        -------
        list
            List of file names.
        list
            List of directory names.
        """
        name = cls.abbreviation or cls.sampler_name
        dirname = os.path.join(outdir, f"{name}_{label}", "")
        return [], [dirname]
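
# A hedged usage sketch (not part of bilby's documented examples): for a
# hypothetical Sampler subclass ``MySampler`` with ``abbreviation = "dy"``,
#
#     files, dirs = MySampler.get_expected_outputs(outdir="outdir", label="run0")
#     assert files == []
#     assert dirs == ["outdir/dy_run0/"]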


class NestedSampler(Sampler):
    sampler_name = "nested_sampler"
    npoints_equiv_kwargs = [
        "nlive",
        "nlives",
        "n_live_points",
        "npoints",
        "npoint",
        "Nlive",
        "num_live_points",
        "num_particles",
    ]
    walks_equiv_kwargs = ["walks", "steps", "nmcmc"]

    @staticmethod
    def reorder_loglikelihoods(
        unsorted_loglikelihoods, unsorted_samples, sorted_samples
    ):
        """Reorders the stored log-likelihoods after they have been reweighted

        This creates a sorting index by matching the reweighted
        `result.samples` against the raw samples, then uses this index to
        sort the log-likelihoods

        Parameters
        ==========
        sorted_samples, unsorted_samples: array-like
            Sorted and unsorted values of the samples. These should be of the
            same shape and contain the same sample values, but in different
            orders
        unsorted_loglikelihoods: array-like
            The log-likelihoods corresponding to the unsorted_samples

        Returns
        =======
        sorted_loglikelihoods: array-like
            The log-likelihoods reordered to match that of the sorted_samples

        """

        idxs = []
        for ii in range(len(unsorted_loglikelihoods)):
            idx = np.where(np.all(sorted_samples[ii] == unsorted_samples, axis=1))[0]
            if len(idx) > 1:
                logger.warning(
                    "Multiple likelihood matches found between sorted and "
                    "unsorted samples. Taking the first match."
                )
            idxs.append(idx[0])
        return unsorted_loglikelihoods[idxs]
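
    # A hedged worked example (not part of bilby): with
    #
    #     unsorted_samples = np.array([[1.0], [2.0], [3.0]])
    #     unsorted_logls = np.array([10.0, 20.0, 30.0])
    #     sorted_samples = np.array([[3.0], [1.0], [2.0]])
    #
    # ``reorder_loglikelihoods(unsorted_logls, unsorted_samples,
    # sorted_samples)`` returns ``array([30., 10., 20.])``, i.e. the
    # log-likelihoods now line up row-for-row with ``sorted_samples``.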

    def log_likelihood(self, theta):
        """
        Since some nested samplers don't call the log_prior method, evaluate
        the prior constraint here.

        Parameters
        ==========
        theta: array_like
            Parameter values at which to evaluate likelihood

        Returns
        =======
        float: log_likelihood
        """
        if self.priors.evaluate_constraints(
            {key: theta[ii] for ii, key in enumerate(self.search_parameter_keys)}
        ):
            return Sampler.log_likelihood(self, theta)
        else:
            return np.nan_to_num(-np.inf)


class MCMCSampler(Sampler):
    sampler_name = "mcmc_sampler"
    nwalkers_equiv_kwargs = ["nwalker", "nwalkers", "draws", "Niter"]
    nburn_equiv_kwargs = ["burn", "nburn"]

    def print_nburn_logging_info(self):
        """Prints logging info as to how nburn was calculated"""
        if type(self.nburn) in [float, int]:
            logger.info(f"Discarding {self.nburn} steps for burn-in")
        elif self.result.max_autocorrelation_time is None:
            logger.info(
                f"Autocorrelation time not calculated, discarding "
                f"{self.nburn} steps for burn-in"
            )
        else:
            logger.info(
                f"Discarding {self.nburn} steps for burn-in, estimated from autocorr"
            )

    def calculate_autocorrelation(self, samples, c=3):
        """Uses the `emcee.autocorr` module to estimate the autocorrelation

        Parameters
        ==========
        samples: array_like
            A chain of samples.
        c: float
            The minimum number of autocorrelation times needed to trust the
            estimate (default: `3`). See `emcee.autocorr.integrated_time`.
        """
        import emcee

        try:
            self.result.max_autocorrelation_time = int(
                np.max(emcee.autocorr.integrated_time(samples, c=c))
            )
            logger.info(f"Max autocorr time = {self.result.max_autocorrelation_time}")
        except emcee.autocorr.AutocorrError as e:
            self.result.max_autocorrelation_time = None
            logger.info(f"Unable to calculate autocorr time: {e}")


class _TemporaryFileSamplerMixin:
    """
    A mixin class to handle storing sampler intermediate products in a temporary
    location. See, e.g., `this SO answer <https://stackoverflow.com/a/547714>`_
    for a basic background on mixins.

    This class makes sure that any subclasses can seamlessly use the temporary
    file functionality.
    """

    short_name = ""

    def __init__(self, temporary_directory, **kwargs):
        super(_TemporaryFileSamplerMixin, self).__init__(**kwargs)
        try:
            from mpi4py import MPI

            using_mpi = MPI.COMM_WORLD.Get_size() > 1
        except ImportError:
            using_mpi = False

        if using_mpi and temporary_directory:
            logger.info(
                "Temporary directory incompatible with MPI, "
                "will run in original directory"
            )
        self.use_temporary_directory = temporary_directory and not using_mpi
        self._outputfiles_basename = None
        self._temporary_outputfiles_basename = None

    def _check_and_load_sampling_time_file(self):
        if os.path.exists(self.time_file_path):
            with open(self.time_file_path, "r") as time_file:
                self.total_sampling_time = float(time_file.readline())
        else:
            self.total_sampling_time = 0

    def _calculate_and_save_sampling_time(self):
        current_time = time.time()
        new_sampling_time = current_time - self.start_time
        self.total_sampling_time += new_sampling_time

        with open(self.time_file_path, "w") as time_file:
            time_file.write(str(self.total_sampling_time))

        self.start_time = current_time

    def _clean_up_run_directory(self):
        if self.use_temporary_directory:
            self._move_temporary_directory_to_proper_path()
            self.kwargs["outputfiles_basename"] = self.outputfiles_basename

    @property
    def outputfiles_basename(self):
        return self._outputfiles_basename

    @outputfiles_basename.setter
    def outputfiles_basename(self, outputfiles_basename):
        if outputfiles_basename is None:
            outputfiles_basename = f"{self.outdir}/{self.short_name}_{self.label}/"
        if not outputfiles_basename.endswith("/"):
            outputfiles_basename += "/"
        check_directory_exists_and_if_not_mkdir(self.outdir)
        self._outputfiles_basename = outputfiles_basename

    @property
    def temporary_outputfiles_basename(self):
        return self._temporary_outputfiles_basename

    @temporary_outputfiles_basename.setter
    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
        if not temporary_outputfiles_basename.endswith("/"):
            temporary_outputfiles_basename += "/"
        self._temporary_outputfiles_basename = temporary_outputfiles_basename
        if os.path.exists(self.outputfiles_basename):
            shutil.copytree(
                self.outputfiles_basename, self.temporary_outputfiles_basename
            )

    def write_current_state(self):
        self._calculate_and_save_sampling_time()
        if self.use_temporary_directory:
            self._move_temporary_directory_to_proper_path()

    def _move_temporary_directory_to_proper_path(self):
        """
        Move the temporary directory back to the proper path.

        Anything in the proper path at this point is removed, including links.
        """
        self._copy_temporary_directory_contents_to_proper_path()
        shutil.rmtree(self.temporary_outputfiles_basename)

    def _copy_temporary_directory_contents_to_proper_path(self):
        """
        Copy the contents of the temporary directory back to the proper path.
        Do not delete the temporary directory.
        """
        logger.info(
            f"Overwriting {self.outputfiles_basename} with {self.temporary_outputfiles_basename}"
        )
        outputfiles_basename_stripped = self.outputfiles_basename.rstrip("/")
        shutil.copytree(
            self.temporary_outputfiles_basename,
            outputfiles_basename_stripped,
            dirs_exist_ok=True,
        )

    def _setup_run_directory(self):
        """
        If using a temporary directory, the output directory is moved to the
        temporary directory.
        Used for Dnest4, Pymultinest, and Ultranest.
        """
        check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
        if self.use_temporary_directory:
            temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
            self.temporary_outputfiles_basename = temporary_outputfiles_basename

            if os.path.exists(self.outputfiles_basename):
                shutil.copytree(
                    self.outputfiles_basename,
                    self.temporary_outputfiles_basename,
                    dirs_exist_ok=True,
                )
            check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)

            self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename
            logger.info(f"Using temporary file {temporary_outputfiles_basename}")
        else:
            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
            logger.info(f"Using output file {self.outputfiles_basename}")
        self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat"


class Error(Exception):
    """Base class for all exceptions raised by this module"""


class SamplerError(Error):
    """Base class for errors related to samplers in this module"""


class ResumeError(Error):
    """Class for errors arising from resuming runs"""


class SamplerNotInstalledError(SamplerError):
    """Base class for errors raised by not-installed samplers"""


class IllegalSamplingSetError(Error):
    """Class for illegal sets of sampling parameters"""


class SamplingMarginalisedParameterError(IllegalSamplingSetError):
    """Class for errors that occur when sampling over marginalized parameters"""