Coverage for bilby/core/sampler/dynesty.py: 73%
524 statements
1import datetime
2import inspect
3import os
4import sys
5import time
6import warnings
8import numpy as np
9from pandas import DataFrame
11from ..result import rejection_sample
12from ..utils import (
13 check_directory_exists_and_if_not_mkdir,
14 latex_plot_format,
15 logger,
16 safe_file_dump,
17)
18from .base_sampler import NestedSampler, Sampler, _SamplingContainer, signal_wrapper
21def _set_sampling_kwargs(args):
22 nact, maxmcmc, proposals, naccept = args
23 _SamplingContainer.nact = nact
24 _SamplingContainer.maxmcmc = maxmcmc
25 _SamplingContainer.proposals = proposals
26 _SamplingContainer.naccept = naccept
29def _prior_transform_wrapper(theta):
30 """Wrapper to the prior transformation. Needed for multiprocessing."""
31 from .base_sampler import _sampling_convenience_dump
33 return _sampling_convenience_dump.priors.rescale(
34 _sampling_convenience_dump.search_parameter_keys, theta
35 )
38def _log_likelihood_wrapper(theta):
39 """Wrapper to the log likelihood. Needed for multiprocessing."""
40 from .base_sampler import _sampling_convenience_dump
42 if _sampling_convenience_dump.priors.evaluate_constraints(
43 {
44 key: theta[ii]
45 for ii, key in enumerate(_sampling_convenience_dump.search_parameter_keys)
46 }
47 ):
48 params = {
49 key: t
50 for key, t in zip(_sampling_convenience_dump.search_parameter_keys, theta)
51 }
52 _sampling_convenience_dump.likelihood.parameters.update(params)
53 if _sampling_convenience_dump.use_ratio:
54 return _sampling_convenience_dump.likelihood.log_likelihood_ratio()
55 else:
56 return _sampling_convenience_dump.likelihood.log_likelihood()
57 else:
58 return np.nan_to_num(-np.inf)
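# Illustrative sketch (not part of the original module): the two wrappers above
# are defined at module level so that dynesty's multiprocessing pool can pickle
# them and dispatch work to child processes. The hypothetical `_square` helper
# below demonstrates the same pattern.
from multiprocessing import Pool


def _square(x):
    # A module-level function can be pickled and sent to worker processes.
    return x * x


if __name__ == "__main__":
    with Pool(2) as pool:
        print(pool.map(_square, [1, 2, 3]))  # prints [1, 4, 9]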
61class Dynesty(NestedSampler):
62 """
63 bilby wrapper of `dynesty.NestedSampler`
64 (https://dynesty.readthedocs.io/en/latest/)
66 All positional and keyword arguments (i.e., the args and kwargs) passed to
67 `run_sampler` will be propagated to `dynesty.NestedSampler`; see the
68 documentation for that class for further help. Under Other Parameters below,
69 we list commonly used kwargs and the Bilby defaults.
71 Parameters
72 ==========
73 likelihood: likelihood.Likelihood
74 An object with a log_l method
75 priors: bilby.core.prior.PriorDict, dict
76 Priors to be used in the search.
77 This has attributes for each parameter to be sampled.
78 outdir: str, optional
79 Name of the output directory
80 label: str, optional
81 Naming scheme of the output files
82 use_ratio: bool, optional
83 Switch to set whether or not you want to use the log-likelihood ratio
84 or just the log-likelihood
85 plot: bool, optional
86 Switch to set whether or not you want to create traceplots
87 skip_import_verification: bool
88 Skips the check if the sampler is installed if true. This is
89 only advisable for testing environments
90 print_method: str ('tqdm')
91 The method to use for printing. The options are:
92 - 'tqdm': use a `tqdm` `pbar`, this is the default.
93 - 'interval-$TIME': print to `stdout` every `$TIME` seconds,
94 e.g., 'interval-10' prints every ten seconds; this does not print at every iteration
95 - else: print to `stdout` at every iteration
96 exit_code: int
97 The code with which the sampler exits if it hasn't finished sampling
98 check_point: bool,
99 If true, use checkpointing.
100 check_point_plot: bool,
101 If true, generate a trace plot along with the checkpoint
102 check_point_delta_t: float (600)
103 The minimum checkpoint period (in seconds). Should the run be
104 interrupted, it can be resumed from the last checkpoint.
105 n_check_point: int, optional (None)
106 The number of steps to take before checking whether to write a checkpoint.
107 resume: bool
108 If true, resume run from checkpoint (if available)
109 maxmcmc: int (5000)
110 The maximum length of the MCMC exploration to find a new point
111 nact: int (2)
112 The number of autocorrelation lengths for MCMC exploration.
113 For use with the :code:`act-walk` and :code:`rwalk` sample methods.
114 See the dynesty guide in the Bilby docs for more details.
115 naccept: int (60)
116 The expected number of accepted steps for MCMC exploration when using
117 the :code:`acceptance-walk` sampling method.
118 rejection_sample_posterior: bool (True)
119 Whether to form the posterior by rejection sampling the nested samples.
120 If False, the nested samples are resampled with repetition. This was
121 the default behaviour in :code:`Bilby<=1.4.1` and leads to
122 non-independent samples being produced.
123 proposals: iterable (None)
124 The proposal methods to use during MCMC. This can be some combination
125 of :code:`"diff", "volumetric"`. See the dynesty guide in the Bilby docs
126 for more details. default=:code:`["diff"]`.
127 rstate: numpy.random.Generator (None)
128 Instance of a numpy random generator for generating random numbers.
129 Also see :code:`seed` in 'Other Parameters'.
131 Other Parameters
132 ================
133 nlive: int, (1000)
134 The number of live points; note this can equivalently be given as
135 one of [nlive, nlives, n_live_points, npoints]
136 bound: {'live', 'live-multi', 'none', 'single', 'multi', 'balls', 'cubes'}, ('live')
137 Method used to select new points
138 sample: {'act-walk', 'acceptance-walk', 'unif', 'rwalk', 'slice',
139 'rslice', 'hslice', 'rwalk_dynesty'}, ('act-walk')
140 Method used to sample uniformly within the likelihood constraints,
141 conditioned on the provided bounds
142 walks: int (100)
143 Number of walks taken if using the dynesty implemented sample methods
144 Note that the default `walks` in dynesty itself is 25, although using
145 `ndim * 10` can be a reasonable rule of thumb for new problems.
146 For :code:`sample="act-walk"` and :code:`sample="rwalk"` this parameter
147 has no impact on the sampling.
148 dlogz: float, (0.1)
149 Stopping criteria
150 seed: int (None)
151 Used to seed the random number generator if :code:`rstate` is not
152 specified.
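Examples
========
A minimal sketch of typical usage via :code:`bilby.run_sampler` (assumes a
user-defined :code:`likelihood` and :code:`priors`)::

    import bilby

    result = bilby.run_sampler(
        likelihood=likelihood,
        priors=priors,
        sampler="dynesty",
        nlive=1000,
        sample="act-walk",
        outdir="outdir",
        label="example",
    )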
153 """
155 sampler_name = "dynesty"
156 sampling_seed_key = "seed"
158 @property
159 def _dynesty_init_kwargs(self):
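# Harvest dynesty's defaults from the sampler __init__ signature, then
# override the bilby-preferred settings below.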
160 params = inspect.signature(self.sampler_init).parameters
161 kwargs = {
162 key: param.default
163 for key, param in params.items()
164 if param.default != param.empty
165 }
166 kwargs["sample"] = "act-walk"
167 kwargs["bound"] = "live"
168 kwargs["update_interval"] = 600
169 kwargs["facc"] = 0.2
170 return kwargs
172 @property
173 def _dynesty_sampler_kwargs(self):
174 params = inspect.signature(self.sampler_class.run_nested).parameters
175 kwargs = {
176 key: param.default
177 for key, param in params.items()
178 if param.default != param.empty
179 }
180 kwargs["save_bounds"] = False
181 if "dlogz" in kwargs:
182 kwargs["dlogz"] = 0.1
183 return kwargs
185 @property
186 def default_kwargs(self):
187 kwargs = self._dynesty_init_kwargs
188 kwargs.update(self._dynesty_sampler_kwargs)
189 kwargs["seed"] = None
190 return kwargs
192 def __init__(
193 self,
194 likelihood,
195 priors,
196 outdir="outdir",
197 label="label",
198 use_ratio=False,
199 plot=False,
200 skip_import_verification=False,
201 check_point=True,
202 check_point_plot=True,
203 n_check_point=None,
204 check_point_delta_t=600,
205 resume=True,
206 nestcheck=False,
207 exit_code=130,
208 print_method="tqdm",
209 maxmcmc=5000,
210 nact=2,
211 naccept=60,
212 rejection_sample_posterior=True,
213 proposals=None,
214 **kwargs,
215 ):
216 self.nact = nact
217 self.naccept = naccept
218 self.maxmcmc = maxmcmc
219 self.proposals = proposals
220 self.print_method = print_method
221 self._translate_kwargs(kwargs)
222 super(Dynesty, self).__init__(
223 likelihood=likelihood,
224 priors=priors,
225 outdir=outdir,
226 label=label,
227 use_ratio=use_ratio,
228 plot=plot,
229 skip_import_verification=skip_import_verification,
230 exit_code=exit_code,
231 **kwargs,
232 )
233 self.n_check_point = n_check_point
234 self.check_point = check_point
235 self.check_point_plot = check_point_plot
236 self.resume = resume
237 self.rejection_sample_posterior = rejection_sample_posterior
238 self._apply_dynesty_boundaries("periodic")
239 self._apply_dynesty_boundaries("reflective")
241 self.nestcheck = nestcheck
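# If n_check_point is not given, choose it so that roughly ten blocks of
# likelihood evaluations fit within each checkpoint period, based on the
# measured evaluation time, with a floor of 10 steps per block.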
243 if self.n_check_point is None:
244 self.n_check_point = (
245 10
246 if np.isnan(self._log_likelihood_eval_time)
247 else max(
248 int(check_point_delta_t / self._log_likelihood_eval_time / 10), 10
249 )
250 )
251 self.check_point_delta_t = check_point_delta_t
252 logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s")
254 self.resume_file = f"{self.outdir}/{self.label}_resume.pickle"
255 self.sampling_time = datetime.timedelta()
256 self.pbar = None
258 @property
259 def sampler_function_kwargs(self):
260 return {key: self.kwargs[key] for key in self._dynesty_sampler_kwargs}
262 @property
263 def sampler_init_kwargs(self):
264 return {key: self.kwargs[key] for key in self._dynesty_init_kwargs}
266 def _translate_kwargs(self, kwargs):
267 kwargs = super()._translate_kwargs(kwargs)
268 if "nlive" not in kwargs:
269 for equiv in self.npoints_equiv_kwargs:
270 if equiv in kwargs:
271 kwargs["nlive"] = kwargs.pop(equiv)
272 if "print_progress" not in kwargs:
273 if "verbose" in kwargs:
274 kwargs["print_progress"] = kwargs.pop("verbose")
275 if "walks" not in kwargs:
276 for equiv in self.walks_equiv_kwargs:
277 if equiv in kwargs:
278 kwargs["walks"] = kwargs.pop(equiv)
279 if "queue_size" not in kwargs:
280 for equiv in self.npool_equiv_kwargs:
281 if equiv in kwargs:
282 kwargs["queue_size"] = kwargs.pop(equiv)
283 if "seed" in kwargs:
284 seed = kwargs.get("seed")
285 if "rstate" not in kwargs:
286 kwargs["rstate"] = np.random.default_rng(seed)
287 else:
288 logger.warning(
289 "Kwargs contain both 'rstate' and 'seed', ignoring 'seed'."
290 )
292 def _verify_kwargs_against_default_kwargs(self):
293 if not self.kwargs["walks"]:
294 self.kwargs["walks"] = 100
295 if self.kwargs["print_func"] is None:
296 self.kwargs["print_func"] = self._print_func
297 if "interval" in self.print_method:
298 self._last_print_time = datetime.datetime.now()
299 self._print_interval = datetime.timedelta(
300 seconds=float(self.print_method.split("-")[1])
301 )
302 Sampler._verify_kwargs_against_default_kwargs(self)
304 @classmethod
305 def get_expected_outputs(cls, outdir=None, label=None):
306 """Get lists of the expected outputs directories and files.
308 These are used by :code:`bilby_pipe` when transferring files via HTCondor.
310 Parameters
311 ----------
312 outdir : str
313 The output directory.
314 label : str
315 The label for the run.
317 Returns
318 -------
319 list
320 List of file names.
321 list
322 List of directory names. Will always be empty for dynesty.
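Examples
--------
An illustrative call (paths shown with a POSIX separator):

>>> Dynesty.get_expected_outputs(outdir="outdir", label="example")
(['outdir/example_resume.pickle', 'outdir/example_dynesty.pickle'], [])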
323 """
324 filenames = []
325 for kind in ["resume", "dynesty"]:
326 filename = os.path.join(outdir, f"{label}_{kind}.pickle")
327 filenames.append(filename)
328 return filenames, []
330 def _print_func(
331 self,
332 results,
333 niter,
334 ncall=None,
335 dlogz=None,
336 stop_val=None,
337 nbatch=None,
338 logl_min=-np.inf,
339 logl_max=np.inf,
340 *args,
341 **kwargs,
342 ):
343 """Replacing status update for dynesty.result.print_func"""
344 if "interval" in self.print_method:
345 _time = datetime.datetime.now()
346 if _time - self._last_print_time < self._print_interval:
347 return
348 else:
349 self._last_print_time = _time
351 # Add time in current run to overall sampling time
352 total_time = self.sampling_time + _time - self.start_time
354 # Remove fractional seconds
355 total_time_str = str(total_time).split(".")[0]
357 # Extract results at the current iteration.
358 loglstar = results.loglstar
359 delta_logz = results.delta_logz
360 logz = results.logz
361 logzvar = results.logzvar
362 nc = results.nc
363 bounditer = results.bounditer
364 eff = results.eff
366 # Adjusting outputs for printing.
367 if delta_logz > 1e6:
368 delta_logz = np.inf
369 if 0.0 <= logzvar <= 1e6:
370 logzerr = np.sqrt(logzvar)
371 else:
372 logzerr = np.nan
373 if logz <= -1e6:
374 logz = -np.inf
375 if loglstar <= -1e6:
376 loglstar = -np.inf
378 if self.use_ratio:
379 key = "logz-ratio"
380 else:
381 key = "logz"
383 # Constructing output.
384 string = list()
385 string.append(f"bound:{bounditer:d}")
386 string.append(f"nc:{nc:3d}")
387 string.append(f"ncall:{ncall:.1e}")
388 string.append(f"eff:{eff:0.1f}%")
389 string.append(f"{key}={logz:0.2f}+/-{logzerr:0.2f}")
390 if nbatch is not None:
391 string.append(f"batch:{nbatch}")
392 if logl_min > -np.inf:
393 string.append(f"logl:{logl_min:.1f} < {loglstar:.1f} < {logl_max:.1f}")
394 if dlogz is not None:
395 string.append(f"dlogz:{delta_logz:0.3f}>{dlogz:0.2g}")
396 else:
397 string.append(f"stop:{stop_val:6.3f}")
398 string = " ".join(string)
400 if self.print_method == "tqdm":
401 self.pbar.set_postfix_str(string, refresh=False)
402 self.pbar.update(niter - self.pbar.n)
403 else:
404 print(f"{niter}it [{total_time_str} {string}]", file=sys.stdout, flush=True)
406 def _apply_dynesty_boundaries(self, key):
407 # The periodic and reflective kwargs passed into dynesty allow the
408 # parameters to wander out of the bounds; both cases are then handled
409 # in the prior transform.
410 selected = list()
411 for ii, param in enumerate(self.search_parameter_keys):
412 if self.priors[param].boundary == key:
413 logger.debug(f"Setting {key} boundary for {param}")
414 selected.append(ii)
415 if len(selected) == 0:
416 selected = None
417 self.kwargs[key] = selected
419 def nestcheck_data(self, out_file):
420 import nestcheck.data_processing
422 ns_run = nestcheck.data_processing.process_dynesty_run(out_file)
423 nestcheck_result = f"{self.outdir}/{self.label}_nestcheck.pickle"
424 safe_file_dump(ns_run, nestcheck_result, "pickle")
426 @property
427 def nlive(self):
428 return self.kwargs["nlive"]
430 @property
431 def sampler_init(self):
432 from dynesty import NestedSampler
434 return NestedSampler
436 @property
437 def sampler_class(self):
438 from dynesty.sampler import Sampler
440 return Sampler
442 def _set_sampling_method(self):
443 """
444 Resolve the sampling method and sampler to use from the provided
445 :code:`bound` and :code:`sample` arguments.
447 This requires registering the :code:`bilby` specific methods in the
448 appropriate locations within :code:`dynesty`.
450 Additionally, some combinations of bound/sample/proposals are not
451 compatible and so we either warn the user or raise an error.
452 """
453 import dynesty
455 _set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept))
457 sample = self.kwargs["sample"]
458 bound = self.kwargs["bound"]
460 if sample not in ["rwalk", "act-walk", "acceptance-walk"] and bound in [
461 "live",
462 "live-multi",
463 ]:
464 logger.info(
465 "Live-point based bound method requested with dynesty sample "
466 f"'{sample}', overwriting to 'multi'"
467 )
468 self.kwargs["bound"] = "multi"
469 elif bound == "live":
470 from .dynesty_utils import LivePointSampler
472 dynesty.dynamicsampler._SAMPLERS["live"] = LivePointSampler
473 elif bound == "live-multi":
474 from .dynesty_utils import MultiEllipsoidLivePointSampler
476 dynesty.dynamicsampler._SAMPLERS[
477 "live-multi"
478 ] = MultiEllipsoidLivePointSampler
479 elif sample == "acceptance-walk":
480 raise DynestySetupError(
481 "bound must be set to live or live-multi for sample=acceptance-walk"
482 )
483 elif self.proposals is None:
484 logger.warning(
485 "No proposals specified using dynesty sampling, defaulting "
486 "to 'volumetric'."
487 )
488 self.proposals = ["volumetric"]
489 _SamplingContainer.proposals = self.proposals
490 elif "diff" in self.proposals:
491 raise DynestySetupError(
492 "bound must be set to live or live-multi to use differential "
493 "evolution proposals"
494 )
496 if sample == "rwalk":
497 logger.info(
498 f"Using the bilby-implemented {sample} sample method with ACT estimated walks. "
499 f"An average of {2 * self.nact} steps will be accepted up to chain length "
500 f"{self.maxmcmc}."
501 )
502 from .dynesty_utils import AcceptanceTrackingRWalk
504 if self.kwargs["walks"] > self.maxmcmc:
505 raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)")
506 if self.nact < 1:
507 raise DynestySetupError("Unable to run with nact < 1")
508 AcceptanceTrackingRWalk.old_act = None
509 dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk()
510 elif sample == "acceptance-walk":
511 logger.info(
512 f"Using the bilby-implemented {sample} sampling with an average of "
513 f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}"
514 )
515 from .dynesty_utils import FixedRWalk
517 dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk()
518 elif sample == "act-walk":
519 logger.info(
520 f"Using the bilby-implemented {sample} sampling tracking the "
521 f"autocorrelation function and thinning by "
522 f"{self.nact} with maximum length {self.nact * self.maxmcmc}"
523 )
524 from .dynesty_utils import ACTTrackingRWalk
526 ACTTrackingRWalk._cache = list()
527 dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk()
528 elif sample == "rwalk_dynesty":
529 sample = sample.strip("_dynesty")
530 self.kwargs["sample"] = sample
531 logger.info(f"Using the dynesty-implemented {sample} sample method")
533 @signal_wrapper
534 def run_sampler(self):
535 import dynesty
537 logger.info(f"Using dynesty version {dynesty.__version__}")
539 self._set_sampling_method()
540 self._setup_pool()
542 if self.resume:
543 self.resume = self.read_saved_state(continuing=True)
545 if self.resume:
546 logger.info("Resume file successfully loaded.")
547 else:
548 if self.kwargs["live_points"] is None:
549 self.kwargs["live_points"] = self.get_initial_points_from_prior(
550 self.nlive
551 )
552 self.kwargs["live_points"] = (*self.kwargs["live_points"], None)
553 self.sampler = self.sampler_init(
554 loglikelihood=_log_likelihood_wrapper,
555 prior_transform=_prior_transform_wrapper,
556 ndim=self.ndim,
557 **self.sampler_init_kwargs,
558 )
559 if self.print_method == "tqdm" and self.kwargs["print_progress"]:
560 from tqdm.auto import tqdm
562 self.pbar = tqdm(file=sys.stdout, initial=self.sampler.it)
564 self.start_time = datetime.datetime.now()
565 if self.check_point:
566 out = self._run_external_sampler_with_checkpointing()
567 else:
568 out = self._run_external_sampler_without_checkpointing()
569 self._update_sampling_time()
571 self._close_pool()
573 # Flushes the output to force a line break
574 if self.pbar is not None:
575 self.pbar = self.pbar.close()
576 print("")
578 check_directory_exists_and_if_not_mkdir(self.outdir)
580 if self.nestcheck:
581 self.nestcheck_data(out)
583 dynesty_result = f"{self.outdir}/{self.label}_dynesty.pickle"
584 safe_file_dump(out, dynesty_result, "dill")
586 self._generate_result(out)
587 self.result.sampling_time = self.sampling_time
589 return self.result
591 def _setup_pool(self):
592 """
593 In addition to the usual steps, we need to set the sampling kwargs on
594 every process. To make sure we get every process, run the kwarg setting
595 more times than we have processes.
596 """
597 super(Dynesty, self)._setup_pool()
598 if self.pool is not None:
599 args = (
600 [(self.nact, self.maxmcmc, self.proposals, self.naccept)]
601 * self.npool
602 * 10
603 )
604 self.pool.map(_set_sampling_kwargs, args)
606 def _generate_result(self, out):
607 """
608 Extract the information we need from the dynesty output. This includes
609 the evidence, nested samples, run statistics. In addition, we generate
610 the posterior samples from the nested samples.
612 Parameters
613 ==========
614 out: dynesty.result.Result
615 The dynesty output.
616 """
617 import dynesty
618 from scipy.special import logsumexp
620 from ..utils.random import rng
622 logwts = out["logwt"]
623 weights = np.exp(logwts - out["logz"][-1])
624 nested_samples = DataFrame(out.samples, columns=self.search_parameter_keys)
625 nested_samples["weights"] = weights
626 nested_samples["log_likelihood"] = out.logl
627 self.result.nested_samples = nested_samples
628 if self.rejection_sample_posterior:
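# Keep each nested sample with probability weight / max(weight); the
# retained samples are independent draws from the posterior.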
629 keep = weights > rng.uniform(0, max(weights), len(weights))
630 self.result.samples = out.samples[keep]
631 self.result.log_likelihood_evaluations = out.logl[keep]
632 logger.info(
633 f"Rejection sampling nested samples to obtain {sum(keep)} posterior samples"
634 )
635 else:
636 self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
637 self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
638 unsorted_loglikelihoods=out.logl,
639 unsorted_samples=out.samples,
640 sorted_samples=self.result.samples,
641 )
642 logger.info("Resampling nested samples to posterior samples in place.")
643 self.result.log_evidence = out.logz[-1]
644 self.result.log_evidence_err = out.logzerr[-1]
645 self.result.information_gain = out.information[-1]
646 self.result.num_likelihood_evaluations = getattr(self.sampler, "ncall", 0)
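# Kish effective sample size evaluated in log space:
# n_eff = (sum of weights)^2 / (sum of squared weights)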
648 logneff = logsumexp(logwts) * 2 - logsumexp(logwts * 2)
649 neffsamples = int(np.exp(logneff))
650 self.result.meta_data["run_statistics"] = dict(
651 nlikelihood=self.result.num_likelihood_evaluations,
652 neffsamples=neffsamples,
653 sampling_time_s=self.sampling_time.seconds,
654 ncores=self.kwargs.get("queue_size", 1),
655 )
656 self.kwargs["rstate"] = None
658 def _update_sampling_time(self):
659 end_time = datetime.datetime.now()
660 self.sampling_time += end_time - self.start_time
661 self.start_time = end_time
663 def _run_external_sampler_without_checkpointing(self):
664 logger.debug("Running sampler without checkpointing")
665 self.sampler.run_nested(**self.sampler_function_kwargs)
666 return self.sampler.results
668 def finalize_sampler_kwargs(self, sampler_kwargs):
669 sampler_kwargs["maxcall"] = self.n_check_point
670 sampler_kwargs["add_live"] = True
672 def _run_external_sampler_with_checkpointing(self):
673 """
674 In order to access the checkpointing, we run the sampler for short
675 periods of time (less than the checkpoint time) and if sufficient
676 time has passed, write a checkpoint before continuing. To get the most
677 informative checkpoint plots, the current live points are added to the
678 chain of nested samples within dynesty and have to be removed before
679 restarting the sampler.
680 """
682 logger.debug("Running sampler with checkpointing")
684 old_ncall = self.sampler.ncall
685 sampler_kwargs = self.sampler_function_kwargs.copy()
686 warnings.filterwarnings(
687 "ignore",
688 message="The sampling was stopped short due to maxiter/maxcall limit*",
689 category=UserWarning,
690 module="dynesty.sampler",
691 )
692 while True:
693 self.finalize_sampler_kwargs(sampler_kwargs)
694 if getattr(self.sampler, "added_live", False):
695 self.sampler._remove_live_points()
696 self.sampler.run_nested(**sampler_kwargs)
697 if self.sampler.ncall == old_ncall:
698 break
699 old_ncall = self.sampler.ncall
701 if os.path.isfile(self.resume_file):
702 last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
703 else:
704 last_checkpoint_s = (
705 datetime.datetime.now() - self.start_time
706 ).total_seconds()
707 if last_checkpoint_s > self.check_point_delta_t:
708 self.write_current_state()
709 self.plot_current_state()
710 if getattr(self.sampler, "added_live", False):
711 self.sampler._remove_live_points()
713 self.sampler.run_nested(**sampler_kwargs)
714 self.write_current_state()
715 self.plot_current_state()
716 return self.sampler.results
718 def _remove_checkpoint(self):
719 """Remove checkpointed state"""
720 if os.path.isfile(self.resume_file):
721 os.remove(self.resume_file)
723 def read_saved_state(self, continuing=False):
724 """
725 Read a pickled saved state of the sampler to disk.
727 If the live points are present and the run is continuing
728 they are removed.
729 The random state must be reset, as this isn't saved by the pickle.
730 `nqueue` is set to a negative number to trigger the queue to be
731 refilled before the first iteration.
732 The previous run time is restored to `self.sampling_time`.
734 Parameters
735 ==========
736 continuing: bool
737 Whether the run is continuing or terminating, if True, the loaded
738 state is mostly written back to disk.
739 """
740 import dill
741 from dynesty import __version__ as dynesty_version
743 from ... import __version__ as bilby_version
745 versions = dict(bilby=bilby_version, dynesty=dynesty_version)
746 if os.path.isfile(self.resume_file):
747 logger.info(f"Reading resume file {self.resume_file}")
748 with open(self.resume_file, "rb") as file:
749 try:
750 sampler = dill.load(file)
751 except EOFError:
752 sampler = None
754 if not hasattr(sampler, "versions"):
755 logger.warning(
756 f"The resume file {self.resume_file} is corrupted or "
757 "the version of bilby has changed between runs. This "
758 "resume file will be ignored."
759 )
760 return False
761 version_warning = (
762 "The {code} version has changed between runs. "
763 "This may cause unpredictable behaviour and/or failure. "
764 "Old version = {old}, new version = {new}."
765 )
766 for code in versions:
767 if not versions[code] == sampler.versions.get(code, None):
768 logger.warning(
769 version_warning.format(
770 code=code,
771 old=sampler.versions.get(code, "None"),
772 new=versions[code],
773 )
774 )
775 del sampler.versions
776 self.sampler = sampler
777 if getattr(self.sampler, "added_live", False) and continuing:
778 self.sampler._remove_live_points()
779 self.sampler.nqueue = -1
780 self.start_time = self.sampler.kwargs.pop("start_time")
781 self.sampling_time = self.sampler.kwargs.pop("sampling_time")
782 self.sampler.queue_size = self.kwargs["queue_size"]
783 self.sampler.pool = self.pool
784 if self.pool is not None:
785 self.sampler.M = self.pool.map
786 else:
787 self.sampler.M = map
788 return True
789 else:
790 logger.info(f"Resume file {self.resume_file} does not exist.")
791 return False
793 def write_current_state_and_exit(self, signum=None, frame=None):
794 if self.pbar is not None:
795 self.pbar = self.pbar.close()
796 super(Dynesty, self).write_current_state_and_exit(signum=signum, frame=frame)
798 def write_current_state(self):
799 """
800 Write the current state of the sampler to disk.
802 The sampler is pickle dumped using `dill`.
803 The sampling time is also stored to get the full CPU time for the run.
805 The check of whether the sampler is picklable is to catch an error
806 when using pytest. Hopefully, this message won't be triggered during
807 normal running.
808 """
810 import dill
811 from dynesty import __version__ as dynesty_version
813 from ... import __version__ as bilby_version
815 if getattr(self, "sampler", None) is None:
816 # Sampler not initialized, not able to write current state
817 return
819 check_directory_exists_and_if_not_mkdir(self.outdir)
820 if hasattr(self, "start_time"):
821 self._update_sampling_time()
822 self.sampler.kwargs["sampling_time"] = self.sampling_time
823 self.sampler.kwargs["start_time"] = self.start_time
824 self.sampler.versions = dict(bilby=bilby_version, dynesty=dynesty_version)
825 self.sampler.pool = None
826 self.sampler.M = map
827 if dill.pickles(self.sampler):
828 safe_file_dump(self.sampler, self.resume_file, dill)
829 logger.info(f"Written checkpoint file {self.resume_file}")
830 else:
831 logger.warning(
832 "Cannot write pickle resume file! "
833 "Job will not resume if interrupted."
834 )
835 self.sampler.pool = self.pool
836 if self.sampler.pool is not None:
837 self.sampler.M = self.sampler.pool.map
839 def dump_samples_to_dat(self):
840 """
841 Save the current posterior samples to a space-separated plain-text
842 file. These are unbiased posterior samples, however, there will not
843 be many of them until the analysis is nearly over.
844 """
845 sampler = self.sampler
846 ln_weights = sampler.saved_logwt - sampler.saved_logz[-1]
848 weights = np.exp(ln_weights)
849 samples = rejection_sample(np.array(sampler.saved_v), weights)
850 nsamples = len(samples)
852 # If we don't have enough samples, don't dump them
853 if nsamples < 100:
854 return
856 filename = f"{self.outdir}/{self.label}_samples.dat"
857 logger.info(f"Writing {nsamples} current samples to {filename}")
859 df = DataFrame(samples, columns=self.search_parameter_keys)
860 df.to_csv(filename, index=False, header=True, sep=" ")
862 def plot_current_state(self):
863 """
864 Make diagnostic plots of the history and current state of the sampler.
866 These plots are a mixture of :code:`dynesty` implemented run and trace
867 plots and our custom stats plot. We also make a copy of the trace plot
868 using the unit hypercube samples to reflect the internal state of the
869 sampler.
871 Any errors during plotting should be handled so that sampling can
872 continue.
873 """
874 if self.check_point_plot:
875 import dynesty.plotting as dyplot
876 import matplotlib.pyplot as plt
878 labels = [label.replace("_", " ") for label in self.search_parameter_keys]
879 try:
880 filename = f"{self.outdir}/{self.label}_checkpoint_trace.png"
881 fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
882 fig.tight_layout()
883 fig.savefig(filename)
884 except (
885 RuntimeError,
886 np.linalg.linalg.LinAlgError,
887 ValueError,
888 OverflowError,
889 ) as e:
890 logger.warning(e)
891 logger.warning("Failed to create dynesty state plot at checkpoint")
892 except Exception as e:
893 logger.warning(
894 f"Unexpected error {e} in dynesty plotting. "
895 "Please report at git.ligo.org/lscsoft/bilby/-/issues"
896 )
897 finally:
898 plt.close("all")
899 try:
900 filename = f"{self.outdir}/{self.label}_checkpoint_trace_unit.png"
901 from copy import deepcopy
903 from dynesty.utils import results_substitute
905 temp = deepcopy(self.sampler.results)
906 temp = results_substitute(temp, dict(samples=temp["samples_u"]))
907 fig = dyplot.traceplot(temp, labels=labels)[0]
908 fig.tight_layout()
909 fig.savefig(filename)
910 except (
911 RuntimeError,
912 np.linalg.linalg.LinAlgError,
913 ValueError,
914 OverflowError,
915 ) as e:
916 logger.warning(e)
917 logger.warning("Failed to create dynesty unit state plot at checkpoint")
918 except Exception as e:
919 logger.warning(
920 f"Unexpected error {e} in dynesty plotting. "
921 "Please report at git.ligo.org/lscsoft/bilby/-/issues"
922 )
923 finally:
924 plt.close("all")
925 try:
926 filename = f"{self.outdir}/{self.label}_checkpoint_run.png"
927 fig, _ = dyplot.runplot(
928 self.sampler.results, logplot=False, use_math_text=False
929 )
930 fig.tight_layout()
931 plt.savefig(filename)
932 except (
933 RuntimeError,
934 np.linalg.linalg.LinAlgError,
935 ValueError,
936 OverflowError,
937 ) as e:
938 logger.warning(e)
939 logger.warning("Failed to create dynesty run plot at checkpoint")
940 except Exception as e:
941 logger.warning(
942 f"Unexpected error {e} in dynesty plotting. "
943 "Please report at git.ligo.org/lscsoft/bilby/-/issues"
944 )
945 finally:
946 plt.close("all")
947 try:
948 filename = f"{self.outdir}/{self.label}_checkpoint_stats.png"
949 fig, _ = dynesty_stats_plot(self.sampler)
950 fig.tight_layout()
951 plt.savefig(filename)
952 except (RuntimeError, ValueError, OverflowError) as e:
953 logger.warning(e)
954 logger.warning("Failed to create dynesty stats plot at checkpoint")
955 except DynestySetupError:
956 logger.debug("Cannot create Dynesty stats plot with dynamic sampler.")
957 except Exception as e:
958 logger.warning(
959 f"Unexpected error {e} in dynesty plotting. "
960 "Please report at git.ligo.org/lscsoft/bilby/-/issues"
961 )
962 finally:
963 plt.close("all")
965 def _run_test(self):
966 """Run the sampler very briefly as a sanity test that it works."""
967 import pandas as pd
969 self._set_sampling_method()
970 self._setup_pool()
971 self.sampler = self.sampler_init(
972 loglikelihood=_log_likelihood_wrapper,
973 prior_transform=_prior_transform_wrapper,
974 ndim=self.ndim,
975 **self.sampler_init_kwargs,
976 )
977 sampler_kwargs = self.sampler_function_kwargs.copy()
978 sampler_kwargs["maxiter"] = 2
980 if self.print_method == "tqdm" and self.kwargs["print_progress"]:
981 from tqdm.auto import tqdm
983 self.pbar = tqdm(file=sys.stdout, initial=self.sampler.it)
984 self.sampler.run_nested(**sampler_kwargs)
985 self._close_pool()
987 if self.pbar is not None:
988 self.pbar = self.pbar.close()
989 print("")
990 N = 100
991 self.result.samples = pd.DataFrame(self.priors.sample(N))[
992 self.search_parameter_keys
993 ].values
994 self.result.nested_samples = self.result.samples
995 self.result.log_likelihood_evaluations = np.ones(N)
996 self.result.log_evidence = 1
997 self.result.log_evidence_err = 0.1
999 return self.result
1001 def prior_transform(self, theta):
1002 """Prior transform method that is passed into the external sampler.
1003 Parameters with periodic or reflective boundaries may wander outside the unit cube; these are mapped back to [0, 1] before being rescaled to the prior.
1005 Parameters
1006 ==========
1007 theta: list
1008 List of sampled values on a unit interval
1010 Returns
1011 =======
1012 list: Properly rescaled sampled values
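Examples
========
A sketch of the underlying rescaling with a stand-alone
:code:`PriorDict` (illustrative only):

>>> from bilby.core.prior import PriorDict, Uniform
>>> priors = PriorDict(dict(a=Uniform(0, 1), b=Uniform(0, 10)))
>>> values = priors.rescale(["a", "b"], [0.5, 0.5])  # -> [0.5, 5.0]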
1014 """
1015 return self.priors.rescale(self._search_parameter_keys, theta)
1018@latex_plot_format
1019def dynesty_stats_plot(sampler):
1020 """
1021 Plot diagnostic statistics from a dynesty run
1023 The plotted quantities per iteration are:
1025 - nc: the number of likelihood calls
1026 - scale: the number of accepted MCMC steps if using :code:`bound="live"`
1027 or :code:`bound="live-multi"`, otherwise, the scale applied to the MCMC
1028 steps
1029 - lifetime: the number of iterations a point stays in the live set
1031 There is also a histogram of the lifetime compared with the theoretical
1032 distribution out to 6 * nlive. To avoid edge effects, the burn-in iterations and the final nlive points are discarded from the comparison.
1034 Parameters
1035 ----------
1036 sampler: dynesty.sampler.Sampler
1037 The sampler object containing the run history.
1039 Returns
1040 -------
1041 fig: matplotlib.figure.Figure
1042 Figure handle for the new plot
1043 axs: numpy.ndarray of matplotlib.axes.Axes
1044 Axes handles for the new plot
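Examples
--------
A sketch of intended use after a (non-dynamic) nested sampling run; assumes
:code:`sampler` is the completed :code:`dynesty` sampler instance::

    fig, axs = dynesty_stats_plot(sampler)
    fig.savefig("outdir/label_checkpoint_stats.png")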
1046 """
1047 import matplotlib.pyplot as plt
1048 from scipy.stats import geom, ks_1samp
1050 fig, axs = plt.subplots(nrows=4, figsize=(8, 8))
1051 data = sampler.saved_run.D
1052 for ax, name in zip(axs, ["nc", "scale"]):
1053 ax.plot(data[name], color="blue")
1054 ax.set_ylabel(name.title())
1055 lifetimes = np.arange(len(data["it"])) - data["it"]
1056 axs[-2].set_ylabel("Lifetime")
1057 if not hasattr(sampler, "nlive"):
1058 raise DynestySetupError("Cannot make stats plot for dynamic sampler.")
1059 nlive = sampler.nlive
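# Point lifetimes are geometrically distributed with p = 1 / nlive; take the
# burn-in as the iteration by which an initial live point survives with
# probability 1 / (2 * nlive).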
1060 burn = int(geom(p=1 / nlive).isf(1 / 2 / nlive))
1061 if len(data["it"]) > burn + sampler.nlive:
1062 axs[-2].plot(np.arange(0, burn), lifetimes[:burn], color="grey")
1063 axs[-2].plot(
1064 np.arange(burn, len(lifetimes) - nlive),
1065 lifetimes[burn:-nlive],
1066 color="blue",
1067 )
1068 axs[-2].plot(
1069 np.arange(len(lifetimes) - nlive, len(lifetimes)),
1070 lifetimes[-nlive:],
1071 color="red",
1072 )
1073 lifetimes = lifetimes[burn:-nlive]
1074 ks_result = ks_1samp(lifetimes, geom(p=1 / nlive).cdf)
1075 axs[-1].hist(
1076 lifetimes,
1077 bins=np.linspace(0, 6 * nlive, 60),
1078 histtype="step",
1079 density=True,
1080 color="blue",
1081 label=f"p value = {ks_result.pvalue:.3f}",
1082 )
1083 axs[-1].plot(
1084 np.arange(1, 6 * nlive),
1085 geom(p=1 / nlive).pmf(np.arange(1, 6 * nlive)),
1086 color="red",
1087 )
1088 axs[-1].set_xlim(0, 6 * nlive)
1089 axs[-1].legend()
1090 axs[-1].set_yscale("log")
1091 else:
1092 axs[-2].plot(
1093 np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey"
1094 )
1095 axs[-2].plot(
1096 np.arange(len(lifetimes) - nlive, len(lifetimes)),
1097 lifetimes[-nlive:],
1098 color="red",
1099 )
1100 axs[-2].set_yscale("log")
1101 axs[-2].set_xlabel("Iteration")
1102 axs[-1].set_xlabel("Lifetime")
1103 return fig, axs
1106class DynestySetupError(Exception):
1107 pass