Coverage for bilby/core/sampler/base_sampler.py: 76%
446 statements
coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
import datetime
import os
import shutil
import signal
import sys
import tempfile
import time

import attr
import numpy as np
from pandas import DataFrame

from ..prior import Constraint, DeltaFunction, Prior, PriorDict
from ..result import Result, read_in_result
from ..utils import (
    Counter,
    check_directory_exists_and_if_not_mkdir,
    command_line_args,
    logger,
)
from ..utils.random import seed as set_seed


@attr.s
class _SamplingContainer:
    """
    A container class for objects that are stored independently in each
    thread for some samplers.

    A single instance of this class lives in this module and can be accessed
    by the individual samplers.

    This includes the:

    - likelihood (bilby.core.likelihood.Likelihood)
    - priors (bilby.core.prior.PriorDict)
    - search_parameter_keys (list)
    - use_ratio (bool)
    """

    likelihood = attr.ib(default=None)
    priors = attr.ib(default=None)
    search_parameter_keys = attr.ib(default=None)
    use_ratio = attr.ib(default=False)


_sampling_convenience_dump = _SamplingContainer()
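# A single module-level instance: the multiprocessing pool initializer below
# populates a separate copy in each worker process, so samplers can look up
# the likelihood and priors without passing them to every call.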


def _initialize_global_variables(
    likelihood,
    priors,
    search_parameter_keys,
    use_ratio,
):
    """
    Store a global copy of the likelihood, priors, and search keys for
    multiprocessing.
    """
    global _sampling_convenience_dump
    _sampling_convenience_dump.likelihood = likelihood
    _sampling_convenience_dump.priors = priors
    _sampling_convenience_dump.search_parameter_keys = search_parameter_keys
    _sampling_convenience_dump.use_ratio = use_ratio


def signal_wrapper(method):
    """
    Decorator to wrap a method of a class to set system signals before running
    and reset them after.

    Parameters
    ==========
    method: callable
        The method to call, this assumes the first argument is `self`
        and that `self` has a `write_current_state_and_exit` method.

    Returns
    =======
    output: callable
        The wrapped method.
    """

    def wrapped(self, *args, **kwargs):
        try:
            old_term = signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
            old_int = signal.signal(signal.SIGINT, self.write_current_state_and_exit)
            old_alarm = signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
            _set = True
        except (AttributeError, ValueError):
            _set = False
            logger.debug(
                "Setting signal attributes unavailable on this system. "
                "This is likely the case if you are running on a Windows machine "
                "and can be safely ignored."
            )
        output = method(self, *args, **kwargs)
        if _set:
            signal.signal(signal.SIGTERM, old_term)
            signal.signal(signal.SIGINT, old_int)
            signal.signal(signal.SIGALRM, old_alarm)
        return output

    return wrapped
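# A usage sketch (hypothetical subclass of Sampler, defined below): decorating
# a long-running method means SIGTERM/SIGINT/SIGALRM trigger
# write_current_state_and_exit (checkpoint, then exit) instead of killing the
# run outright, e.g.
#
#     class MySampler(Sampler):
#         @signal_wrapper
#         def run_sampler(self):
#             ...  # long-running sampling loop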


class Sampler(object):
    """A sampler object to aid in setting up an inference run

    Parameters
    ==========
    likelihood: likelihood.Likelihood
        An object with a `log_likelihood` method
    priors: bilby.core.prior.PriorDict, dict
        Priors to be used in the search.
        This has attributes for each parameter to be sampled.
    external_sampler: str, Sampler, optional
        A string containing the module name of the sampler or an instance of
        this class
    outdir: str, optional
        Name of the output directory
    label: str, optional
        Naming scheme of the output files
    use_ratio: bool, optional
        Switch to set whether or not you want to use the log-likelihood ratio
        or just the log-likelihood
    plot: bool, optional
        Switch to set whether or not you want to create traceplots
    injection_parameters: dict
        A dictionary of the injection parameters
    meta_data: dict
        A dictionary of extra meta data to store in the result
    result_class: bilby.core.result.Result, or a child thereof
        The result class to use. By default, `bilby.core.result.Result` is
        used, but a class which inherits from it can be given to provide
        additional methods.
    soft_init: bool, optional
        Switch to enable a soft initialization that prevents the likelihood
        from being tested before running the sampler. This is relevant when
        using custom likelihoods that must NOT be initialized on the main
        thread when using multiprocessing, e.g. when using tensorflow in the
        likelihood.
    **kwargs: dict
        Additional keyword arguments

    Attributes
    ==========
    likelihood: likelihood.Likelihood
        An object with a `log_likelihood` method
    priors: bilby.core.prior.PriorDict
        Priors to be used in the search.
        This has attributes for each parameter to be sampled.
    external_sampler: Module
        An external module containing an implementation of a sampler.
    outdir: str
        Name of the output directory
    label: str
        Naming scheme of the output files
    use_ratio: bool
        Switch to set whether or not you want to use the log-likelihood ratio
        or just the log-likelihood
    plot: bool
        Switch to set whether or not you want to create traceplots
    skip_import_verification: bool
        Skips the check of whether the sampler is installed if true. This is
        only advisable for testing environments
    result: bilby.core.result.Result
        Container for the results of the sampling run
    exit_code: int
        System exit code to return on interrupt
    kwargs: dict
        Dictionary of keyword arguments that can be used in the external
        sampler
    hard_exit: bool
        Whether the implemented sampler exits hard (:code:`os._exit` rather
        than :code:`sys.exit`). The latter can be caught as a
        :code:`SystemExit`; the former cannot.
    sampler_name : str
        Name of the sampler. This is used when creating the output directory
        for the sampler.
    abbreviation : str
        Abbreviated name of the sampler. Does not have to be specified in
        child classes. If set to a value other than :code:`None`, this will
        be used instead of :code:`sampler_name` when creating the output
        directory.

    Raises
    ======
    TypeError:
        If external_sampler is neither a string nor an instance of this class.
        If not all likelihood.parameters have been defined.
    ImportError:
        If the external_sampler string does not refer to a sampler that is
        installed on this system.
    AttributeError:
        If some of the priors can't be sampled.
    """

    sampler_name = "sampler"
    abbreviation = None
    default_kwargs = dict()
    npool_equiv_kwargs = [
        "npool",
        "queue_size",
        "threads",
        "nthreads",
        "cores",
        "n_pool",
    ]
    sampling_seed_equiv_kwargs = ["sampling_seed", "seed", "random_seed"]
    hard_exit = False
    sampling_seed_key = None
    """Name of the keyword argument for setting the sampling seed for the
    specific sampler. If a specific sampler does not have a sampling seed
    option, this should be left as None.
    """
    check_point_equiv_kwargs = ["check_point_deltaT", "check_point_delta_t"]

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        injection_parameters=None,
        meta_data=None,
        result_class=None,
        likelihood_benchmark=False,
        soft_init=False,
        exit_code=130,
        npool=1,
        **kwargs,
    ):
        self.likelihood = likelihood
        if isinstance(priors, PriorDict):
            self.priors = priors
        else:
            self.priors = PriorDict(priors)
        self.label = label
        self.outdir = outdir
        self.injection_parameters = injection_parameters
        self.meta_data = meta_data
        self.use_ratio = use_ratio
        self._npool = npool
        if not skip_import_verification:
            self._verify_external_sampler()
        self.external_sampler_function = None
        self.plot = plot
        self.likelihood_benchmark = likelihood_benchmark

        self._search_parameter_keys = list()
        self._fixed_parameter_keys = list()
        self._constraint_parameter_keys = list()
        self._initialise_parameters()
        self._log_information_about_priors_and_likelihood()

        self.exit_code = exit_code

        self._log_likelihood_eval_time = np.nan
        if not soft_init:
            self._verify_parameters()
            self._log_likelihood_eval_time = self._time_likelihood()
            self._verify_use_ratio()

        self.kwargs = kwargs

        self._check_cached_result(result_class)

        self._log_summary_for_sampler()

        self.result = self._initialise_result(result_class)
        self.likelihood_count = None
        if self.likelihood_benchmark:
            self.likelihood_count = Counter()

    @property
    def search_parameter_keys(self):
        """list: List of parameter keys that are being sampled"""
        return self._search_parameter_keys

    @property
    def fixed_parameter_keys(self):
        """list: List of parameter keys that are not being sampled"""
        return self._fixed_parameter_keys

    @property
    def constraint_parameter_keys(self):
        """list: List of parameters providing prior constraints"""
        return self._constraint_parameter_keys

    @property
    def ndim(self):
        """int: Number of dimensions of the search parameter space"""
        return len(self._search_parameter_keys)

    @property
    def kwargs(self):
        """dict: Container for the kwargs. Has more sophisticated logic in subclasses"""
        return self._kwargs

    @kwargs.setter
    def kwargs(self, kwargs):
        self._kwargs = self.default_kwargs.copy()
        self._translate_kwargs(kwargs)
        self._kwargs.update(kwargs)
        self._verify_kwargs_against_default_kwargs()

    def _translate_kwargs(self, kwargs):
        """Translate keyword arguments.

        Default only translates the sampling seed if the sampler has
        :code:`sampling_seed_key` set.
        """
        if self.sampling_seed_key and self.sampling_seed_key not in kwargs:
            for equiv in self.sampling_seed_equiv_kwargs:
                if equiv in kwargs:
                    kwargs[self.sampling_seed_key] = kwargs.pop(equiv)
                    set_seed(kwargs[self.sampling_seed_key])
        return kwargs
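    # For example (a sketch): in a subclass with sampling_seed_key = "seed",
    # passing random_seed=42 is translated to kwargs["seed"] = 42, and bilby's
    # global random generator is seeded via set_seed as a side effect.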

    @property
    def external_sampler_name(self):
        return self.__class__.__name__.lower()

    def _verify_external_sampler(self):
        external_sampler_name = self.external_sampler_name
        try:
            __import__(external_sampler_name)
        except (ImportError, SystemExit):
            raise SamplerNotInstalledError(
                f"Sampler {external_sampler_name} is not installed on this system"
            )

    def _verify_kwargs_against_default_kwargs(self):
        """
        Check if the kwargs are contained in the list of available arguments
        of the external sampler.
        """
        args = self.default_kwargs
        bad_keys = []
        for user_input in self.kwargs.keys():
            if user_input not in args:
                logger.warning(
                    f"Supplied argument '{user_input}' not an argument of '{self.__class__.__name__}', removing."
                )
                bad_keys.append(user_input)
        for key in bad_keys:
            self.kwargs.pop(key)

    def _initialise_parameters(self):
        """
        Go through the list of priors and add keys to the fixed and search
        parameter key lists depending on whether the respective parameter
        is fixed.
        """
        for key in self.priors:
            if (
                isinstance(self.priors[key], Prior)
                and self.priors[key].is_fixed is False
            ):
                self._search_parameter_keys.append(key)
            elif isinstance(self.priors[key], Constraint):
                self._constraint_parameter_keys.append(key)
            elif isinstance(self.priors[key], DeltaFunction):
                self.likelihood.parameters[key] = self.priors[key].sample()
                self._fixed_parameter_keys.append(key)

    def _log_information_about_priors_and_likelihood(self):
        logger.info("Analysis priors:")
        for key in self._search_parameter_keys + self._constraint_parameter_keys:
            logger.info(f"{key}={self.priors[key]}")
        for key in self._fixed_parameter_keys:
            logger.info(f"{key}={self.priors[key].peak}")
        logger.info(f"Analysis likelihood class: {self.likelihood.__class__}")
        logger.info(
            f"Analysis likelihood noise evidence: {self.likelihood.noise_log_likelihood()}"
        )

    def _initialise_result(self, result_class):
        """
        Returns
        =======
        bilby.core.result.Result: An initial template for the result
        """
        result_kwargs = dict(
            label=self.label,
            outdir=self.outdir,
            sampler=self.__class__.__name__.lower(),
            search_parameter_keys=self._search_parameter_keys,
            fixed_parameter_keys=self._fixed_parameter_keys,
            constraint_parameter_keys=self._constraint_parameter_keys,
            priors=self.priors,
            meta_data=self.meta_data,
            injection_parameters=self.injection_parameters,
            sampler_kwargs=self.kwargs,
            use_ratio=self.use_ratio,
        )

        if result_class is None:
            result = Result(**result_kwargs)
        elif issubclass(result_class, Result):
            result = result_class(**result_kwargs)
        else:
            raise ValueError(f"Input result_class={result_class} not understood")

        return result

    def _verify_parameters(self):
        """Evaluate a set of parameters drawn from the prior

        Tests if the likelihood evaluation passes

        Raises
        ======
        TypeError
            Likelihood can't be evaluated.
        """

        if self.priors.test_has_redundant_keys():
            raise IllegalSamplingSetError(
                "Your sampling set contains redundant parameters."
            )

        theta = self.priors.sample_subset_constrained_as_array(
            self.search_parameter_keys, size=1
        )[:, 0]
        try:
            self.log_likelihood(theta)
        except TypeError as e:
            raise TypeError(
                f"Likelihood evaluation failed with message: \n'{e}'\n"
                f"Have you specified all the parameters:\n{self.likelihood.parameters}"
            )

    def _time_likelihood(self, n_evaluations=100):
        """Times the likelihood evaluation and prints an info message

        Parameters
        ==========
        n_evaluations: int
            The number of evaluations to estimate the evaluation time from

        Returns
        =======
        log_likelihood_eval_time: float
            The time (in s) it took for one likelihood evaluation
        """

        t1 = datetime.datetime.now()
        for _ in range(n_evaluations):
            theta = self.priors.sample_subset_constrained_as_array(
                self._search_parameter_keys, size=1
            )[:, 0]
            self.log_likelihood(theta)
        total_time = (datetime.datetime.now() - t1).total_seconds()
        log_likelihood_eval_time = total_time / n_evaluations

        if log_likelihood_eval_time == 0:
            log_likelihood_eval_time = np.nan
            logger.info("Unable to measure single likelihood time")
        else:
            logger.info(
                f"Single likelihood evaluation took {log_likelihood_eval_time:.3e} s"
            )
        return log_likelihood_eval_time

    def _verify_use_ratio(self):
        """
        Checks if use_ratio is set. Prints a warning if use_ratio is set but
        not properly implemented.
        """
        try:
            self.priors.sample_subset(self.search_parameter_keys)
        except (KeyError, AttributeError):
            logger.error(
                f"Cannot sample from priors with keys: {self.search_parameter_keys}."
            )
            raise
        if self.use_ratio is False:
            logger.debug("use_ratio set to False")
            return

        ratio_is_nan = np.isnan(self.likelihood.log_likelihood_ratio())

        if self.use_ratio is True and ratio_is_nan:
            logger.warning(
                "You have requested to use the log_likelihood_ratio, but it "
                "returns a NaN"
            )
        elif self.use_ratio is None and not ratio_is_nan:
            logger.debug("use_ratio not spec. but gives valid answer, setting True")
            self.use_ratio = True

    def prior_transform(self, theta):
        """Prior transform method that is passed into the external sampler.

        Parameters
        ==========
        theta: list
            List of sampled values on a unit interval

        Returns
        =======
        list: Properly rescaled sampled values
        """
        return self.priors.rescale(self._search_parameter_keys, theta)
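    # For example, with a single Uniform(minimum=0, maximum=10) search prior,
    # prior_transform([0.5]) rescales the unit-cube point 0.5 to 5.0.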

    def log_prior(self, theta):
        """
        Parameters
        ==========
        theta: list
            List of sampled values in the search parameter space

        Returns
        =======
        float: Joint ln prior probability of theta
        """
        params = {key: t for key, t in zip(self._search_parameter_keys, theta)}
        return self.priors.ln_prob(params)

    def log_likelihood(self, theta):
        """
        Parameters
        ==========
        theta: list
            List of values for the likelihood parameters

        Returns
        =======
        float: Log-likelihood or log-likelihood-ratio given the current
            likelihood.parameter values
        """
        if self.likelihood_benchmark:
            try:
                self.likelihood_count.increment()
            except AttributeError:
                pass
        params = {key: t for key, t in zip(self._search_parameter_keys, theta)}
        self.likelihood.parameters.update(params)
        if self.use_ratio:
            return self.likelihood.log_likelihood_ratio()
        else:
            return self.likelihood.log_likelihood()
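    # Note: theta must be ordered to match search_parameter_keys; the two are
    # zipped together to update likelihood.parameters before evaluation.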

    def get_random_draw_from_prior(self):
        """Get a random draw from the prior distribution

        Returns
        =======
        draw: array_like
            An ndim-length array of values drawn from the prior. Parameters
            with delta-function (or fixed) priors are not returned
        """
        new_sample = self.priors.sample()
        draw = np.array(list(new_sample[key] for key in self._search_parameter_keys))
        self.check_draw(draw)
        return draw

    def get_initial_points_from_prior(self, npoints=1):
        """Method to draw a set of live points from the prior

        This iterates over draws from the prior until all the samples have a
        finite prior and likelihood (relevant for constrained priors).

        Parameters
        ==========
        npoints: int
            The number of values to return

        Returns
        =======
        unit_cube, parameters, likelihood: tuple of array_like
            unit_cube (nlive, ndim) is an array of the prior samples from the
            unit cube, parameters (nlive, ndim) is the unit_cube array
            transformed to the target space, while likelihood (nlive) are the
            likelihood evaluations.
        """
        from ..utils.random import rng

        logger.info("Generating initial points from the prior")
        unit_cube = []
        parameters = []
        likelihood = []
        while len(unit_cube) < npoints:
            unit = rng.uniform(0, 1, self.ndim)
            theta = self.prior_transform(unit)
            if self.check_draw(theta, warning=False):
                unit_cube.append(unit)
                parameters.append(theta)
                likelihood.append(self.log_likelihood(theta))

        return np.array(unit_cube), np.array(parameters), np.array(likelihood)

    def check_draw(self, theta, warning=True):
        """
        Checks if the draw will generate an infinite prior or likelihood

        Also catches the output of `numpy.nan_to_num`.

        Parameters
        ==========
        theta: array_like
            Parameter values at which to evaluate likelihood
        warning: bool
            Whether or not to print a warning

        Returns
        =======
        bool
            True if the likelihood and prior are finite, False otherwise
        """
        log_p = self.log_prior(theta)
        log_l = self.log_likelihood(theta)
        return self._check_bad_value(
            val=log_p, warning=warning, theta=theta, label="prior"
        ) and self._check_bad_value(
            val=log_l, warning=warning, theta=theta, label="likelihood"
        )

    @staticmethod
    def _check_bad_value(val, warning, theta, label):
        val = np.abs(val)
        bad_values = [np.inf, np.nan_to_num(np.inf)]
        if val in bad_values or np.isnan(val):
            if warning:
                logger.warning(f"Prior draw {theta} has inf {label}")
            return False
        return True
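    # np.nan_to_num(np.inf) is the largest representable float (~1.8e308), so
    # values hitting that cap are treated as bad: they are almost certainly
    # infinities that were clipped by a nan_to_num call upstream.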

    def run_sampler(self):
        """A template method to run in subclasses"""
        pass

    def _run_test(self):
        """
        TODO: Implement this method
        Raises
        ======
        ValueError: in any case
        """
        raise ValueError("Method not yet implemented")

    def _check_cached_result(self, result_class=None):
        """Check if the cached data file exists and can be used"""

        if command_line_args.clean:
            logger.debug("Command line argument clean given, forcing rerun")
            self.cached_result = None
            return

        try:
            self.cached_result = read_in_result(
                outdir=self.outdir, label=self.label, result_class=result_class
            )
        except IOError:
            self.cached_result = None

        if command_line_args.use_cached:
            logger.debug("Command line argument cached given, no cache check performed")
            return

        logger.debug("Checking cached data")
        if self.cached_result:
            check_keys = ["search_parameter_keys", "fixed_parameter_keys"]
            use_cache = True
            for key in check_keys:
                if (
                    self.cached_result._check_attribute_match_to_other_object(key, self)
                    is False
                ):
                    logger.debug(f"Cached value {key} is unmatched")
                    use_cache = False
            try:
                # Recursively check the dictionaries, allowing for numpy arrays
                np.testing.assert_equal(
                    self.meta_data["likelihood"],
                    self.cached_result.meta_data["likelihood"],
                )
            except AssertionError:
                use_cache = False
            if use_cache is False:
                self.cached_result = None

    def _log_summary_for_sampler(self):
        """Print a summary of the sampler used and its kwargs"""
        if self.cached_result is None:
            kwargs_print = self.kwargs.copy()
            for k in kwargs_print:
                if isinstance(kwargs_print[k], (list, np.ndarray)):
                    array_repr = np.array(kwargs_print[k])
                    if array_repr.size > 10:
                        kwargs_print[k] = f"array_like, shape={array_repr.shape}"
                elif isinstance(kwargs_print[k], DataFrame):
                    kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}"
            logger.info(
                f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}"
            )

    def calc_likelihood_count(self):
        if self.likelihood_benchmark:
            self.result.num_likelihood_evaluations = self.likelihood_count.value
        else:
            return None

    @property
    def npool(self):
        for key in self.npool_equiv_kwargs:
            if key in self.kwargs:
                return self.kwargs[key]
        return self._npool

    def _log_interruption(self, signum=None):
        if signum == 14:  # SIGALRM on POSIX systems
            logger.info(
                f"Run interrupted by alarm signal {signum}: checkpoint and exit on {self.exit_code}"
            )
        else:
            logger.info(
                f"Run interrupted by signal {signum}: checkpoint and exit on {self.exit_code}"
            )

    def write_current_state_and_exit(self, signum=None, frame=None):
        """
        Make sure that if a pool of jobs is running only the parent tries to
        checkpoint and exit. Only the parent has a 'pool' attribute.

        For samplers that must hard exit (typically due to a non-Python
        process), use :code:`os._exit`, which cannot be intercepted. Other
        samplers' exits can be caught as a :code:`SystemExit`.
        """
        if self.npool in (1, None) or getattr(self, "pool", None) is not None:
            self._log_interruption(signum=signum)
            self.write_current_state()
            self._close_pool()
            if self.hard_exit:
                os._exit(self.exit_code)
            else:
                sys.exit(self.exit_code)

    def _close_pool(self):
        if getattr(self, "pool", None) is not None:
            logger.info("Starting to close worker pool.")
            self.pool.close()
            self.pool.join()
            self.pool = None
            self.kwargs["pool"] = self.pool
            logger.info("Finished closing worker pool.")

    def _setup_pool(self):
        if self.kwargs.get("pool", None) is not None:
            logger.info("Using user defined pool.")
            self.pool = self.kwargs["pool"]
        elif self.npool is not None and self.npool > 1:
            logger.info(f"Setting up multiprocessing pool with {self.npool} processes")
            import multiprocessing

            self.pool = multiprocessing.Pool(
                processes=self.npool,
                initializer=_initialize_global_variables,
                initargs=(
                    self.likelihood,
                    self.priors,
                    self._search_parameter_keys,
                    self.use_ratio,
                ),
            )
        else:
            self.pool = None
            _initialize_global_variables(
                likelihood=self.likelihood,
                priors=self.priors,
                search_parameter_keys=self._search_parameter_keys,
                use_ratio=self.use_ratio,
            )
        self.kwargs["pool"] = self.pool

    def write_current_state(self):
        raise NotImplementedError()

    @classmethod
    def get_expected_outputs(cls, outdir=None, label=None):
        """Get lists of the expected output directories and files.

        These are used by :code:`bilby_pipe` when transferring files via
        HTCondor. Both can be empty. Defaults to a single directory:
        :code:`"{outdir}/{name}_{label}/"`, where :code:`name` is
        :code:`abbreviation` if it is defined for the sampler class,
        otherwise it defaults to :code:`sampler_name`.

        Parameters
        ----------
        outdir : str
            The output directory.
        label : str
            The label for the run.

        Returns
        -------
        list
            List of file names.
        list
            List of directory names.
        """
        name = cls.abbreviation or cls.sampler_name
        dirname = os.path.join(outdir, f"{name}_{label}", "")
        return [], [dirname]
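    # For example, on a subclass with sampler_name = "nested_sampler" and no
    # abbreviation, get_expected_outputs(outdir="outdir", label="run1")
    # returns ([], ["outdir/nested_sampler_run1/"]).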


class NestedSampler(Sampler):
    sampler_name = "nested_sampler"
    npoints_equiv_kwargs = [
        "nlive",
        "nlives",
        "n_live_points",
        "npoints",
        "npoint",
        "Nlive",
        "num_live_points",
        "num_particles",
    ]
    walks_equiv_kwargs = ["walks", "steps", "nmcmc"]

    @staticmethod
    def reorder_loglikelihoods(
        unsorted_loglikelihoods, unsorted_samples, sorted_samples
    ):
        """Reorders the stored log-likelihoods after the samples have been reweighted

        This creates a sorting index by matching the reweighted
        `result.samples` against the raw samples, then uses this index to
        sort the loglikelihoods.

        Parameters
        ==========
        sorted_samples, unsorted_samples: array-like
            Sorted and unsorted values of the samples. These should be of the
            same shape and contain the same sample values, but in different
            orders
        unsorted_loglikelihoods: array-like
            The loglikelihoods corresponding to the unsorted_samples

        Returns
        =======
        sorted_loglikelihoods: array-like
            The loglikelihoods reordered to match that of the sorted_samples
        """

        idxs = []
        for ii in range(len(unsorted_loglikelihoods)):
            idx = np.where(np.all(sorted_samples[ii] == unsorted_samples, axis=1))[0]
            if len(idx) > 1:
                logger.warning(
                    "Multiple likelihood matches found between sorted and "
                    "unsorted samples. Taking the first match."
                )
            idxs.append(idx[0])
        return unsorted_loglikelihoods[idxs]
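    # A small worked example: for unsorted_samples [[1], [2], [3]] (as arrays)
    # with unsorted_loglikelihoods [-1., -2., -3.] and sorted_samples
    # [[3], [1], [2]], the returned array is [-3., -1., -2.].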

    def log_likelihood(self, theta):
        """
        Since some nested samplers don't call the log_prior method, evaluate
        the prior constraint here.

        Parameters
        ==========
        theta: array_like
            Parameter values at which to evaluate likelihood

        Returns
        =======
        float: log_likelihood
        """
        if self.priors.evaluate_constraints(
            {key: theta[ii] for ii, key in enumerate(self.search_parameter_keys)}
        ):
            return Sampler.log_likelihood(self, theta)
        else:
            return np.nan_to_num(-np.inf)
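    # np.nan_to_num(-np.inf) is the most negative representable float
    # (~ -1.8e308), a finite sentinel that nested samplers generally handle
    # more gracefully than a true -inf.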


class MCMCSampler(Sampler):
    sampler_name = "mcmc_sampler"
    nwalkers_equiv_kwargs = ["nwalker", "nwalkers", "draws", "Niter"]
    nburn_equiv_kwargs = ["burn", "nburn"]

    def print_nburn_logging_info(self):
        """Prints logging info as to how nburn was calculated"""
        if type(self.nburn) in [float, int]:
            logger.info(f"Discarding {self.nburn} steps for burn-in")
        elif self.result.max_autocorrelation_time is None:
            logger.info(
                f"Autocorrelation time not calculated, discarding "
                f"{self.nburn} steps for burn-in"
            )
        else:
            logger.info(
                f"Discarding {self.nburn} steps for burn-in, estimated from autocorr"
            )

    def calculate_autocorrelation(self, samples, c=3):
        """Uses the `emcee.autocorr` module to estimate the autocorrelation

        Parameters
        ==========
        samples: array_like
            A chain of samples.
        c: float
            The minimum number of autocorrelation times needed to trust the
            estimate (default: `3`). See `emcee.autocorr.integrated_time`.
        """
        import emcee

        try:
            self.result.max_autocorrelation_time = int(
                np.max(emcee.autocorr.integrated_time(samples, c=c))
            )
            logger.info(f"Max autocorr time = {self.result.max_autocorrelation_time}")
        except emcee.autocorr.AutocorrError as e:
            self.result.max_autocorrelation_time = None
            logger.info(f"Unable to calculate autocorr time: {e}")


class _TemporaryFileSamplerMixin:
    """
    A mixin class to handle storing sampler intermediate products in a
    temporary location. See, e.g., `this SO answer
    <https://stackoverflow.com/a/547714>`_ for a basic background on mixins.

    This class makes sure that any subclasses can seamlessly use the
    temporary file functionality.
    """

    short_name = ""

    def __init__(self, temporary_directory, **kwargs):
        super(_TemporaryFileSamplerMixin, self).__init__(**kwargs)
        try:
            from mpi4py import MPI

            using_mpi = MPI.COMM_WORLD.Get_size() > 1
        except ImportError:
            using_mpi = False

        if using_mpi and temporary_directory:
            logger.info(
                "Temporary directory incompatible with MPI, "
                "will run in original directory"
            )
        self.use_temporary_directory = temporary_directory and not using_mpi
        self._outputfiles_basename = None
        self._temporary_outputfiles_basename = None

    def _check_and_load_sampling_time_file(self):
        if os.path.exists(self.time_file_path):
            with open(self.time_file_path, "r") as time_file:
                self.total_sampling_time = float(time_file.readline())
        else:
            self.total_sampling_time = 0

    def _calculate_and_save_sampling_time(self):
        current_time = time.time()
        new_sampling_time = current_time - self.start_time
        self.total_sampling_time += new_sampling_time

        with open(self.time_file_path, "w") as time_file:
            time_file.write(str(self.total_sampling_time))

        self.start_time = current_time

    def _clean_up_run_directory(self):
        if self.use_temporary_directory:
            self._move_temporary_directory_to_proper_path()
            self.kwargs["outputfiles_basename"] = self.outputfiles_basename

    @property
    def outputfiles_basename(self):
        return self._outputfiles_basename

    @outputfiles_basename.setter
    def outputfiles_basename(self, outputfiles_basename):
        if outputfiles_basename is None:
            outputfiles_basename = f"{self.outdir}/{self.short_name}_{self.label}/"
        if not outputfiles_basename.endswith("/"):
            outputfiles_basename += "/"
        check_directory_exists_and_if_not_mkdir(self.outdir)
        self._outputfiles_basename = outputfiles_basename

    @property
    def temporary_outputfiles_basename(self):
        return self._temporary_outputfiles_basename

    @temporary_outputfiles_basename.setter
    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
        if not temporary_outputfiles_basename.endswith("/"):
            temporary_outputfiles_basename += "/"
        self._temporary_outputfiles_basename = temporary_outputfiles_basename
        if os.path.exists(self.outputfiles_basename):
            shutil.copytree(
                self.outputfiles_basename, self.temporary_outputfiles_basename
            )

    def write_current_state(self):
        self._calculate_and_save_sampling_time()
        if self.use_temporary_directory:
            self._move_temporary_directory_to_proper_path()

    def _move_temporary_directory_to_proper_path(self):
        """
        Move the temporary directory back to the proper path.

        Anything in the proper path at this point is removed, including links.
        """
        self._copy_temporary_directory_contents_to_proper_path()
        shutil.rmtree(self.temporary_outputfiles_basename)

    def _copy_temporary_directory_contents_to_proper_path(self):
        """
        Copy the temporary directory back to the proper path.
        Do not delete the temporary directory.
        """
        logger.info(
            f"Overwriting {self.outputfiles_basename} with {self.temporary_outputfiles_basename}"
        )
        outputfiles_basename_stripped = self.outputfiles_basename.rstrip("/")
        shutil.copytree(
            self.temporary_outputfiles_basename,
            outputfiles_basename_stripped,
            dirs_exist_ok=True,
        )

    def _setup_run_directory(self):
        """
        If using a temporary directory, the output directory is moved to the
        temporary directory.
        Used for Dnest4, Pymultinest, and Ultranest.
        """
        check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
        if self.use_temporary_directory:
            temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
            self.temporary_outputfiles_basename = temporary_outputfiles_basename

            if os.path.exists(self.outputfiles_basename):
                shutil.copytree(
                    self.outputfiles_basename,
                    self.temporary_outputfiles_basename,
                    dirs_exist_ok=True,
                )
            check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)

            self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename
            logger.info(f"Using temporary file {temporary_outputfiles_basename}")
        else:
            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
            logger.info(f"Using output file {self.outputfiles_basename}")
        self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat"


class Error(Exception):
    """Base class for all exceptions raised by this module"""


class SamplerError(Error):
    """Base class for errors related to samplers in this module"""


class ResumeError(Error):
    """Class for errors arising from resuming runs"""


class SamplerNotInstalledError(SamplerError):
    """Base class for errors raised by samplers that are not installed"""


class IllegalSamplingSetError(Error):
    """Class for illegal sets of sampling parameters"""


class SamplingMarginalisedParameterError(IllegalSamplingSetError):
    """Class for errors that occur when sampling over marginalised parameters"""