Coverage for bilby/bilby_mcmc/sampler.py: 57%
752 statements
import datetime
import os
import time
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.optimize import differential_evolution

from ..core.result import rejection_sample
from ..core.sampler.base_sampler import (
    MCMCSampler,
    ResumeError,
    SamplerError,
    _sampling_convenience_dump,
    signal_wrapper,
)
from ..core.utils import (
    check_directory_exists_and_if_not_mkdir,
    logger,
    random,
    safe_file_dump,
)
from . import proposals
from .chain import Chain, Sample
from .utils import LOGLKEY, LOGPKEY, ConvergenceInputs, ParallelTemperingInputs


class Bilby_MCMC(MCMCSampler):
    """The built-in Bilby MCMC sampler

    Parameters
    ----------
    likelihood: likelihood.Likelihood
        An object with a log_l method
    priors: bilby.core.prior.PriorDict, dict
        Priors to be used in the search.
        This has attributes for each parameter to be sampled.
    outdir: str, optional
        Name of the output directory
    label: str, optional
        Naming scheme of the output files
    use_ratio: bool, optional
        Switch to set whether or not you want to use the log-likelihood ratio
        or just the log-likelihood
    skip_import_verification: bool
        Skips the check of whether the sampler is installed if true. This is
        only advisable for testing environments
    check_point_plot: bool
        If true, create plots at the check point
    check_point_delta_t: float
        The time in seconds after which to checkpoint (defaults to 30 minutes)
    diagnostic: bool
        If true, create deep-diagnostic plots used for checking convergence
        problems.
    resume: bool
        If true, resume from any existing check point files
    exit_code: int
        The code on which to raise if exiting
    nsamples: int (1000)
        The number of samples to draw
    nensemble: int (1)
        The number of ensemble-chains to run (with periodic communication)
    pt_ensemble: bool (False)
        If true, run a parallel-tempered set of chains for each
        ensemble-chain (in which case the total number of chains is
        nensemble * ntemps). Else, only the zero-ensemble chain is run with
        parallel-tempering (in which case the total number of chains is
        nensemble + ntemps - 1).
    ntemps: int (1)
        The number of parallel-tempered chains to run
    Tmax: float, (None)
        If given, the maximum temperature from which to set the initial
        temperature-ladder
    Tmax_from_SNR: float (20)
        (Alternative to Tmax): The SNR to estimate an appropriate Tmax from.
    initial_betas: list (None)
        (Alternative to Tmax and Tmax_from_SNR): If given, an initial choice of
        the inverse temperature ladder.
    pt_rejection_sample: bool (False)
        If true, use rejection sampling to draw samples from the pt-chains.
    adapt, adapt_t0, adapt_nu: bool, float, float (True, 100, 10)
        Whether to use adaptation and the adaptation parameters.
        See arXiv:1501.05823 for a description of adapt_t0 and adapt_nu.
    burn_in_nact, thin_by_nact, fixed_discard: float, float, float (10, 1, 0)
        The number of auto-correlation times to discard for burn-in and to
        thin by. The fixed_discard is the number of steps discarded before
        automatic autocorrelation time analysis begins.
    autocorr_c: float (5)
        The step-size for the window search. See emcee.autocorr.integrated_time
        for additional details.
    L1steps: int
        The number of internal steps to take. Improves the scaling performance
        of multiprocessing. Note, all ACTs are calculated based on the saved
        steps. So, the total ACT (or number of steps) is L1steps * tau
        (or L1steps * position).
    L2steps: int
        The number of steps to take before swapping between parallel-tempered
        and ensemble chains.
    npool: int
        The number of multiprocessing cores to use. For efficiency, this
        should evenly divide the total number of chains.
    printdt: float
        Print an update on the progress every printdt s. Note, each print
        requires an evaluation of the ACT so short print times are unwise.
    min_tau: int (1)
        The minimum allowed ACT. Can be used to force a larger ACT.
    proposal_cycle: str, bilby.core.sampler.bilby_mcmc.proposals.ProposalCycle
        Either a string pointing to one of the built-in proposal cycles or,
        a proposal cycle.
    stop_after_convergence: bool (False)
        If running with parallel-tempered chains, stop updating the chains once
        they have converged. After this time, random samples will be drawn at
        swap time.
    fixed_tau: int
        A fixed value for the ACT: used for testing purposes.
    tau_window: int, None
        Using tau', a previous estimate of tau, calculate the new tau using
        the last tau_window * tau' steps. If None, the entire chain is used.
    evidence_method: str, [stepping_stone, thermodynamic]
        The evidence calculation method to use. Defaults to stepping_stone, but
        the results of all available methods are stored in the ln_z_dict.
    initial_sample_method: str
        Method to draw the initial sample. Either "prior" (a random draw
        from the prior) or "maximize" (use an optimization approach to attempt
        to find the maximum posterior estimate).
    initial_sample_dict: dict
        A dictionary of initial sample values. If given, these overwrite the
        corresponding entries of the initial sample drawn using
        initial_sample_method.
    normalize_prior: bool
        When False, disables calculation of the constraint normalization factor
        during prior probability computation. Default value is True.
    verbose: bool
        Whether to print diagnostic output during the run.
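
    Examples
    --------
    A minimal, illustrative invocation through the standard bilby interface.
    The Gaussian likelihood, model, and prior below are placeholders for this
    sketch, not requirements of the sampler::

        import bilby
        import numpy as np

        def model(x, mu):
            return mu * np.ones_like(x)

        x = np.linspace(0, 1, 100)
        y = model(x, 2) + np.random.normal(0, 1, 100)
        likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, model, sigma=1)
        priors = dict(mu=bilby.core.prior.Uniform(-5, 5, "mu"))
        result = bilby.run_sampler(
            likelihood=likelihood,
            priors=priors,
            sampler="bilby_mcmc",
            nsamples=1000,
            ntemps=4,
            npool=4,
        )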
136 """
    default_kwargs = dict(
        nsamples=1000,
        nensemble=1,
        pt_ensemble=False,
        ntemps=1,
        Tmax=None,
        Tmax_from_SNR=20,
        initial_betas=None,
        adapt=True,
        adapt_t0=100,
        adapt_nu=10,
        pt_rejection_sample=False,
        burn_in_nact=10,
        thin_by_nact=1,
        fixed_discard=0,
        autocorr_c=5,
        L1steps=100,
        L2steps=3,
        printdt=60,
        check_point_delta_t=1800,
        min_tau=1,
        proposal_cycle="default",
        stop_after_convergence=False,
        fixed_tau=None,
        tau_window=None,
        evidence_method="stepping_stone",
        initial_sample_method="prior",
        initial_sample_dict=None,
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        skip_import_verification=True,
        check_point_plot=True,
        diagnostic=False,
        resume=True,
        exit_code=130,
        verbose=True,
        normalize_prior=True,
        **kwargs,
    ):

        super(Bilby_MCMC, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            skip_import_verification=skip_import_verification,
            exit_code=exit_code,
            **kwargs,
        )

        self.check_point_plot = check_point_plot
        self.diagnostic = diagnostic
        self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
        self.L1steps = self.kwargs["L1steps"]
        self.L2steps = self.kwargs["L2steps"]
        self.normalize_prior = normalize_prior
        self.pt_inputs = ParallelTemperingInputs(
            **{key: self.kwargs[key] for key in ParallelTemperingInputs._fields}
        )
        self.convergence_inputs = ConvergenceInputs(
            **{key: self.kwargs[key] for key in ConvergenceInputs._fields}
        )
        self.proposal_cycle = self.kwargs["proposal_cycle"]
        self.pt_rejection_sample = self.kwargs["pt_rejection_sample"]
        self.evidence_method = self.kwargs["evidence_method"]
        self.initial_sample_method = self.kwargs["initial_sample_method"]
        self.initial_sample_dict = self.kwargs["initial_sample_dict"]

        self.printdt = self.kwargs["printdt"]
        self.check_point_delta_t = self.kwargs["check_point_delta_t"]
        check_directory_exists_and_if_not_mkdir(self.outdir)
        self.resume = resume
        self.resume_file = "{}/{}_resume.pickle".format(self.outdir, self.label)

        self.verify_configuration()
        self.verbose = verbose

    def verify_configuration(self):
        if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1:
            logger.warning("Burn-in inefficiency fraction greater than 10%")

    def _translate_kwargs(self, kwargs):
        kwargs = super()._translate_kwargs(kwargs)
        if "printdt" not in kwargs:
            for equiv in ["print_dt", "print_update"]:
                if equiv in kwargs:
                    kwargs["printdt"] = kwargs.pop(equiv)
        if "npool" not in kwargs:
            for equiv in self.npool_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["npool"] = kwargs.pop(equiv)
        if "check_point_delta_t" not in kwargs:
            for equiv in self.check_point_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["check_point_delta_t"] = kwargs.pop(equiv)

    @property
    def target_nsamples(self):
        return self.kwargs["target_nsamples"]

    @signal_wrapper
    def run_sampler(self):
        self._setup_pool()
        self.setup_chain_set()
        self.start_time = datetime.datetime.now()
        self.draw()
        self._close_pool()
        self.check_point(ignore_time=True)

        self.result = self.add_data_to_result(
            result=self.result,
            ptsampler=self.ptsampler,
            outdir=self.outdir,
            label=self.label,
            make_plots=self.check_point_plot,
        )

        return self.result

    @staticmethod
    def add_data_to_result(result, ptsampler, outdir, label, make_plots):
        result.samples = ptsampler.samples
        result.log_likelihood_evaluations = result.samples[LOGLKEY].to_numpy()
        result.log_prior_evaluations = result.samples[LOGPKEY].to_numpy()
        ptsampler.compute_evidence(
            outdir=outdir,
            label=label,
            make_plots=make_plots,
        )
        result.log_evidence = ptsampler.ln_z
        result.log_evidence_err = ptsampler.ln_z_err
        result.sampling_time = datetime.timedelta(seconds=ptsampler.sampling_time)
        result.meta_data["bilby_mcmc"] = dict(
            tau=ptsampler.tau,
            convergence_inputs=ptsampler.convergence_inputs._asdict(),
            pt_inputs=ptsampler.pt_inputs._asdict(),
            total_steps=ptsampler.position,
            nsamples=ptsampler.nsamples,
        )
        if ptsampler.pool is not None:
            npool = ptsampler.pool._processes
        else:
            npool = 1
        result.meta_data["run_statistics"] = dict(
            nlikelihood=ptsampler.position * ptsampler.L1steps * ptsampler._nsamplers,
            neffsamples=ptsampler.nsamples * ptsampler.convergence_inputs.thin_by_nact,
            sampling_time_s=result.sampling_time.seconds,
            ncores=npool,
        )

        return result

    def setup_chain_set(self):
        if self.read_current_state() and self.resume is True:
            self.ptsampler.pool = self.pool
        else:
            self.init_ptsampler()

    def init_ptsampler(self):

        logger.info(f"Initializing BilbyPTMCMCSampler with:\n{self.get_setup_string()}")
        self.ptsampler = BilbyPTMCMCSampler(
            convergence_inputs=self.convergence_inputs,
            pt_inputs=self.pt_inputs,
            proposal_cycle=self.proposal_cycle,
            pt_rejection_sample=self.pt_rejection_sample,
            pool=self.pool,
            use_ratio=self.use_ratio,
            evidence_method=self.evidence_method,
            initial_sample_method=self.initial_sample_method,
            initial_sample_dict=self.initial_sample_dict,
            normalize_prior=self.normalize_prior,
        )

    def get_setup_string(self):
        string = (
            f"  Convergence settings: {self.convergence_inputs}\n"
            f"  Parallel-tempering settings: {self.pt_inputs}\n"
            f"  proposal_cycle: {self.proposal_cycle}\n"
            f"  pt_rejection_sample: {self.pt_rejection_sample}"
        )
        return string

    def draw(self):
        self._steps_since_last_print = 0
        self._time_since_last_print = 0
        logger.info(f"Drawing {self.target_nsamples} samples")
        logger.info(f"Checkpoint every check_point_delta_t={self.check_point_delta_t}s")
        logger.info(f"Print update every printdt={self.printdt}s")

        while True:
            t0 = datetime.datetime.now()
            self.ptsampler.step_all_chains()
            dt = (datetime.datetime.now() - t0).total_seconds()
            self.ptsampler.sampling_time += dt
            self._time_since_last_print += dt
            self._steps_since_last_print += self.ptsampler.L1steps

            if self._time_since_last_print > self.printdt:
                tp0 = datetime.datetime.now()
                self.print_progress()
                tp = datetime.datetime.now()
                ppt_frac = (tp - tp0).total_seconds() / self._time_since_last_print
                if ppt_frac > 0.01:
                    logger.warning(
                        f"Non-negligible print progress time (ppt_frac={ppt_frac:0.2f})"
                    )
                self._steps_since_last_print = 0
                self._time_since_last_print = 0

            self.check_point()

            if self.ptsampler.nsamples_last >= self.target_nsamples:
                # Perform a second check without cached values
                if self.ptsampler.nsamples_nocache >= self.target_nsamples:
                    logger.info("Reached convergence: exiting sampling")
                    break

    def check_point(self, ignore_time=False):
        tS = (datetime.datetime.now() - self.start_time).total_seconds()
        if os.path.isfile(self.resume_file):
            tR = time.time() - os.path.getmtime(self.resume_file)
        else:
            tR = np.inf

        if ignore_time or np.min([tS, tR]) > self.check_point_delta_t:
            logger.info("Checkpoint start")
            self.write_current_state()
            self.print_long_progress()
            logger.info("Checkpoint finished")

    def _remove_checkpoint(self):
        """Remove checkpointed state"""
        if os.path.isfile(self.resume_file):
            os.remove(self.resume_file)

    def read_current_state(self):
        """Read the existing resume file

        Returns
        -------
        success: boolean
            If true, resume file was successfully loaded, otherwise false

        """
        if os.path.isfile(self.resume_file) is False or not os.path.getsize(
            self.resume_file
        ):
            return False
        import dill

        with open(self.resume_file, "rb") as file:
            ptsampler = dill.load(file)
        if not isinstance(ptsampler, BilbyPTMCMCSampler):
            logger.debug("Malformed resume file, ignoring")
            return False
        self.ptsampler = ptsampler
        if self.ptsampler.pt_inputs != self.pt_inputs:
            msg = (
                f"pt_inputs has changed: {self.ptsampler.pt_inputs} "
                f"-> {self.pt_inputs}"
            )
            raise ResumeError(msg)
        self.ptsampler.set_convergence_inputs(self.convergence_inputs)
        self.ptsampler.pt_rejection_sample = self.pt_rejection_sample

        logger.info(
            f"Loaded resume file {self.resume_file} "
            f"with {self.ptsampler.position} steps "
            f"setup:\n{self.get_setup_string()}"
        )
        return True

    def write_current_state(self):
        import dill

        if not hasattr(self, "ptsampler"):
            logger.debug("Attempted checkpoint before initialization")
            return
        logger.debug("Check point")
        check_directory_exists_and_if_not_mkdir(self.outdir)

        _pool = self.ptsampler.pool
        self.ptsampler.pool = None
        if dill.pickles(self.ptsampler):
            safe_file_dump(self.ptsampler, self.resume_file, dill)
            logger.info("Written checkpoint file {}".format(self.resume_file))
        else:
            logger.warning(
                "Cannot write pickle resume file! Job may not resume if interrupted."
            )
            # Touch the file to postpone next check-point attempt
            Path(self.resume_file).touch(exist_ok=True)
        self.ptsampler.pool = _pool

    def print_long_progress(self):
        self.print_per_proposal()
        self.print_tau_dict()
        if self.ptsampler.ntemps > 1:
            self.print_pt_acceptance()
        if self.ptsampler.nensemble > 1:
            self.print_ensemble_acceptance()
        if self.check_point_plot:
            self.plot_progress(
                self.ptsampler, self.label, self.outdir, self.priors, self.diagnostic
            )
            self.ptsampler.compute_evidence(
                outdir=self.outdir, label=self.label, make_plots=True
            )

    def print_ensemble_acceptance(self):
        logger.info(f"Ensemble swaps = {self.ptsampler.swap_counter['ensemble']}")
        logger.info(self.ptsampler.ensemble_proposal_cycle)

    def print_progress(self):
        position = self.ptsampler.position

        # Total sampling time
        sampling_time = datetime.timedelta(seconds=self.ptsampler.sampling_time)
        time = str(sampling_time).split(".")[0]

        # Time for last evaluation set
        time_per_eval_ms = (
            1000 * self._time_since_last_print / self._steps_since_last_print
        )

        # Pull out progress summary
        tau = self.ptsampler.tau
        nsamples = self.ptsampler.nsamples
        minimum_index = self.ptsampler.primary_sampler.chain.minimum_index
        method = self.ptsampler.primary_sampler.chain.minimum_index_method
        mindex_str = f"{minimum_index:0.2e}({method})"
        alpha = self.ptsampler.primary_sampler.acceptance_ratio
        maxl = self.ptsampler.primary_sampler.chain.max_log_likelihood

        nlikelihood = position * self.L1steps * self.ptsampler._nsamplers
        eff = 100 * nsamples / nlikelihood

        # Estimated time to finish (ETF)
        if tau < np.inf:
            remaining_samples = self.target_nsamples - nsamples
            remaining_evals = (
                remaining_samples
                * self.convergence_inputs.thin_by_nact
                * tau
                * self.L1steps
            )
            remaining_time_s = time_per_eval_ms * 1e-3 * remaining_evals
            remaining_time_dt = datetime.timedelta(seconds=remaining_time_s)
            if remaining_samples > 0:
                remaining_time = str(remaining_time_dt).split(".")[0]
            else:
                remaining_time = "0"
        else:
            remaining_time = "-"

        msg = (
            f"{position:0.2e}|{time}|{mindex_str}|t={tau:0.0f}|"
            f"n={nsamples:0.0f}|a={alpha:0.2f}|e={eff:0.1e}%|"
            f"{time_per_eval_ms:0.2f}ms/ev|maxl={maxl:0.2f}|"
            f"ETF={remaining_time}"
        )

        if self.pt_rejection_sample:
            count = self.ptsampler.rejection_sampling_count
            rse = 100 * count / nsamples
            msg += f"|rse={rse:0.2f}%"

        if self.verbose:
            print(msg, flush=True)
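
    # An illustrative progress line (all values below are made up for this
    # sketch) reads
    #
    #     1.20e+03|0:12:34|2.00e+02(safe)|t=12|n=512|a=0.21|e=4.3e-01%|1.52ms/ev|maxl=105.32|ETF=0:03:10
    #
    # i.e. position | sampling time | burn-in (minimum) index and the method
    # used to set it | tau | nsamples | acceptance ratio | sampling
    # efficiency | time per evaluation | maximum log-likelihood | estimated
    # time to finish.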

    def print_per_proposal(self):
        logger.info("Zero-temperature proposals:")
        for prop in self.ptsampler[0].proposal_cycle.proposal_list:
            logger.info(prop)

    def print_pt_acceptance(self):
        logger.info(f"Temperature swaps = {self.ptsampler.swap_counter['temperature']}")
        for column in self.ptsampler.sampler_list_of_tempered_lists:
            for ii, sampler in enumerate(column):
                total = sampler.pt_accepted + sampler.pt_rejected
                beta = sampler.beta
                if total > 0:
                    ratio = f"{sampler.pt_accepted / total:0.2f}"
                else:
                    ratio = "-"
                logger.info(
                    f"Temp:{ii}<->{ii+1}|"
                    f"beta={beta:0.4g}|"
                    f"hot-samp={sampler.nsamples}|"
                    f"swap={ratio}|"
                    f"conv={sampler.chain.converged}|"
                )

    def print_tau_dict(self):
        msg = f"Current taus={self.ptsampler.primary_sampler.chain.tau_dict}"
        logger.info(msg)

    @staticmethod
    def plot_progress(ptsampler, label, outdir, priors, diagnostic=False):
        logger.info("Creating diagnostic plots")
        for ii, row in ptsampler.sampler_dictionary.items():
            for jj, sampler in enumerate(row):
                plot_label = f"{label}_E{sampler.Eindex}_T{sampler.Tindex}"
                if diagnostic is True or sampler.beta == 1:
                    sampler.chain.plot(
                        outdir=outdir,
                        label=plot_label,
                        priors=priors,
                        all_samples=ptsampler.samples,
                    )

    @classmethod
    def get_expected_outputs(cls, outdir=None, label=None):
        """Get lists of the expected outputs directories and files.

        These are used by :code:`bilby_pipe` when transferring files via HTCondor.

        Parameters
        ----------
        outdir : str
            The output directory.
        label : str
            The label for the run.

        Returns
        -------
        list
            List of file names.
        list
            List of directory names. Will always be empty for bilby_mcmc.
        """
        filenames = [os.path.join(outdir, f"{label}_resume.pickle")]
        return filenames, []
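

# An illustrative sketch of how the expected outputs above are composed (the
# outdir and label values are placeholders):
#
#     files, dirs = Bilby_MCMC.get_expected_outputs(outdir="outdir", label="run1")
#     # files == [os.path.join("outdir", "run1_resume.pickle")], dirs == []
#
# This is the same resume file written by ``write_current_state`` and read
# back by ``read_current_state`` when ``resume=True``.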


class BilbyPTMCMCSampler(object):
    def __init__(
        self,
        convergence_inputs,
        pt_inputs,
        proposal_cycle,
        pt_rejection_sample,
        pool,
        use_ratio,
        evidence_method,
        initial_sample_method,
        initial_sample_dict,
        normalize_prior=True,
    ):
        self.set_pt_inputs(pt_inputs)
        self.use_ratio = use_ratio
        self.initial_sample_method = initial_sample_method
        self.initial_sample_dict = initial_sample_dict
        self.normalize_prior = normalize_prior
        self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
        self.set_convergence_inputs(convergence_inputs)
        self.pt_rejection_sample = pt_rejection_sample
        self.pool = pool
        self.evidence_method = evidence_method

        # Initialize counters
        self.swap_counter = Counter()
        self.swap_counter["temperature"] = 0
        self.swap_counter["L2-temperature"] = 0
        self.swap_counter["ensemble"] = 0
        self.swap_counter["L2-ensemble"] = int(self.L2steps / 2) + 1

        self._nsamples_dict = {}
        self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(
            _sampling_convenience_dump.priors
        )
        self.sampling_time = 0
        self.ln_z_dict = dict()
        self.ln_z_err_dict = dict()

    def get_initial_betas(self):
        pt_inputs = self.pt_inputs
        if self.ntemps == 1:
            betas = np.array([1])
        elif pt_inputs.initial_betas is not None:
            betas = np.array(pt_inputs.initial_betas)
        elif pt_inputs.Tmax is not None:
            betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps)
        elif pt_inputs.Tmax_from_SNR is not None:
            ndim = len(_sampling_convenience_dump.priors.non_fixed_keys)
            target_hot_likelihood = ndim / 2
            Tmax = pt_inputs.Tmax_from_SNR**2 / (2 * target_hot_likelihood)
            betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps)
        else:
            raise SamplerError("Unable to set temperature ladder from inputs")

        if len(betas) != self.ntemps:
            raise SamplerError("Temperatures do not match ntemps")

        return betas
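
    # A worked example of the SNR-based ladder above (assumed numbers): with
    # Tmax_from_SNR=20 and ndim=4 search parameters, the target hot
    # log-likelihood is ndim / 2 = 2, so Tmax = 20**2 / (2 * 2) = 100 and,
    # for ntemps=4,
    #
    #     betas = np.logspace(0, -np.log10(100), 4)
    #           # -> [1.0, 0.215..., 0.046..., 0.01]
    #
    # i.e. a geometric ladder in temperature from beta=1 down to beta=1/Tmax.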

    def setup_sampler_dictionary(self, convergence_inputs, proposal_cycle):

        betas = self.get_initial_betas()
        logger.info(
            f"Initializing BilbyPTMCMCSampler with:"
            f"ntemps={self.ntemps}, "
            f"nensemble={self.nensemble}, "
            f"pt_ensemble={self.pt_ensemble}, "
            f"initial_betas={betas}, "
            f"initial_sample_method={self.initial_sample_method}, "
            f"initial_sample_dict={self.initial_sample_dict}\n"
        )
        self.sampler_dictionary = dict()
        for Tindex, beta in enumerate(betas):
            if beta == 1 or self.pt_ensemble:
                n = self.nensemble
            else:
                n = 1
            temp_sampler_list = [
                BilbyMCMCSampler(
                    beta=beta,
                    Tindex=Tindex,
                    Eindex=Eindex,
                    convergence_inputs=convergence_inputs,
                    proposal_cycle=proposal_cycle,
                    use_ratio=self.use_ratio,
                    initial_sample_method=self.initial_sample_method,
                    initial_sample_dict=self.initial_sample_dict,
                    normalize_prior=self.normalize_prior,
                )
                for Eindex in range(n)
            ]
            self.sampler_dictionary[Tindex] = temp_sampler_list

        # Store data
        self._nsamplers = len(self.sampler_list)

    @property
    def sampler_list(self):
        """A list of all individual samplers"""
        return [s for item in self.sampler_dictionary.values() for s in item]

    @sampler_list.setter
    def sampler_list(self, sampler_list):
        for sampler in sampler_list:
            self.sampler_dictionary[sampler.Tindex][sampler.Eindex] = sampler

    def sampler_list_by_column(self, column):
        return [row[column] for row in self.sampler_dictionary.values()]

    @property
    def sampler_list_of_tempered_lists(self):
        if self.pt_ensemble:
            return [self.sampler_list_by_column(ii) for ii in range(self.nensemble)]
        else:
            return [self.sampler_list_by_column(0)]

    @property
    def tempered_sampler_list(self):
        return [s for s in self.sampler_list if s.beta < 1]

    @property
    def zerotemp_sampler_list(self):
        return [s for s in self.sampler_list if s.beta == 1]

    @property
    def primary_sampler(self):
        return self.sampler_dictionary[0][0]

    def set_pt_inputs(self, pt_inputs):
        logger.info(f"Setting parallel tempering inputs={pt_inputs}")
        self.pt_inputs = pt_inputs

        # Pull out only what is needed
        self.ntemps = pt_inputs.ntemps
        self.nensemble = pt_inputs.nensemble
        self.pt_ensemble = pt_inputs.pt_ensemble
        self.adapt = pt_inputs.adapt
        self.adapt_t0 = pt_inputs.adapt_t0
        self.adapt_nu = pt_inputs.adapt_nu

    def set_convergence_inputs(self, convergence_inputs):
        logger.info(f"Setting convergence_inputs={convergence_inputs}")
        self.convergence_inputs = convergence_inputs
        self.L1steps = convergence_inputs.L1steps
        self.L2steps = convergence_inputs.L2steps
        for sampler in self.sampler_list:
            sampler.set_convergence_inputs(convergence_inputs)

    @property
    def tau(self):
        return self.primary_sampler.chain.tau

    @property
    def minimum_index(self):
        return self.primary_sampler.chain.minimum_index

    @property
    def nsamples(self):
        pos = self.primary_sampler.chain.position
        if hasattr(self, "_nsamples_dict") is False:
            self._nsamples_dict = {}
        if pos in self._nsamples_dict:
            return self._nsamples_dict[pos]
        logger.debug(f"Calculating nsamples at {pos}")
        self._nsamples_dict[pos] = self._calculate_nsamples()
        return self._nsamples_dict[pos]

    @property
    def nsamples_last(self):
        if len(self._nsamples_dict) > 0:
            return list(self._nsamples_dict.values())[-1]
        else:
            return 0

    @property
    def nsamples_nocache(self):
        for sampler in self.sampler_list:
            sampler.chain.tau_nocache
        pos = self.primary_sampler.chain.position
        self._nsamples_dict[pos] = self._calculate_nsamples()
        return self._nsamples_dict[pos]

    def _calculate_nsamples(self):
        nsamples_list = []
        for sampler in self.zerotemp_sampler_list:
            nsamples_list.append(sampler.nsamples)
        if self.pt_rejection_sample:
            for samp in self.sampler_list[1:]:
                nsamples_list.append(
                    len(samp.rejection_sample_zero_temperature_samples())
                )
        return sum(nsamples_list)

    @property
    def samples(self):
        cached_samples = getattr(self, "_cached_samples", (False,))
        if cached_samples[0] == self.position:
            return cached_samples[1]

        sample_list = []
        for sampler in self.zerotemp_sampler_list:
            sample_list.append(sampler.samples)
        if self.pt_rejection_sample:
            for sampler in self.tempered_sampler_list:
                sample_list.append(sampler.samples)
        samples = pd.concat(sample_list, ignore_index=True)
        self._cached_samples = (self.position, samples)
        return samples

    @property
    def position(self):
        return self.primary_sampler.chain.position

    @property
    def evaluations(self):
        return int(self.position * len(self.sampler_list))

    def __getitem__(self, index):
        return self.sampler_list[index]

    def step_all_chains(self):
        if self.pool:
            self.sampler_list = self.pool.map(call_step, self.sampler_list)
        else:
            for ii, sampler in enumerate(self.sampler_list):
                self.sampler_list[ii] = sampler.step()

        if self.nensemble > 1 and self.swap_counter["L2-ensemble"] >= self.L2steps:
            self.swap_counter["ensemble"] += 1
            self.swap_counter["L2-ensemble"] = 0
            self.ensemble_step()

        if self.ntemps > 1 and self.swap_counter["L2-temperature"] >= self.L2steps:
            self.swap_counter["temperature"] += 1
            self.swap_counter["L2-temperature"] = 0
            self.swap_tempered_chains()
            if self.position < self.adapt_t0 * 10:
                if self.adapt:
                    self.adapt_temperatures()
            elif self.adapt:
                logger.info(
                    f"Adaptation of temperature chains finished at step {self.position}"
                )
                self.adapt = False

        self.swap_counter["L2-ensemble"] += 1
        self.swap_counter["L2-temperature"] += 1

    @staticmethod
    def _get_sample_to_swap(sampler):
        if not (sampler.chain.converged and sampler.stop_after_convergence):
            v = sampler.chain[-1]
        else:
            v = sampler.chain.random_sample
        logl = v[LOGLKEY]
        return v, logl

    def swap_tempered_chains(self):
        if self.pt_ensemble:
            Eindexs = range(self.nensemble)
        else:
            Eindexs = [0]
        for Eindex in Eindexs:
            for Tindex in range(self.ntemps - 1):
                sampleri = self.sampler_dictionary[Tindex][Eindex]
                vi, logli = self._get_sample_to_swap(sampleri)
                betai = sampleri.beta

                samplerj = self.sampler_dictionary[Tindex + 1][Eindex]
                vj, loglj = self._get_sample_to_swap(samplerj)
                betaj = samplerj.beta

                dbeta = betaj - betai
                with np.errstate(over="ignore"):
                    alpha_swap = np.exp(dbeta * (logli - loglj))

                if random.rng.uniform(0, 1) <= alpha_swap:
                    sampleri.chain[-1] = vj
                    samplerj.chain[-1] = vi
                    self.sampler_dictionary[Tindex][Eindex] = sampleri
                    self.sampler_dictionary[Tindex + 1][Eindex] = samplerj
                    sampleri.pt_accepted += 1
                else:
                    sampleri.pt_rejected += 1
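
    # The swap above is the standard parallel-tempering Metropolis move: for
    # adjacent inverse temperatures beta_i > beta_j, the pair of current
    # samples is exchanged with probability
    #
    #     alpha_swap = min(1, exp[(beta_j - beta_i) * (logl_i - logl_j)])
    #
    # so a hotter chain holding a higher likelihood than its colder
    # neighbour is always swapped in.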

    def ensemble_step(self):
        for Tindex, sampler_list in self.sampler_dictionary.items():
            if len(sampler_list) > 1:
                for Eindex, sampler in enumerate(sampler_list):
                    curr = sampler.chain.current_sample
                    proposal = self.ensemble_proposal_cycle.get_proposal()
                    complement = [s.chain for s in sampler_list if s != sampler]
                    prop, log_factor = proposal(sampler.chain, complement)
                    logp = sampler.log_prior(prop)

                    if logp == -np.inf:
                        sampler.reject_proposal(curr, proposal)
                        self.sampler_dictionary[Tindex][Eindex] = sampler
                        continue

                    prop[LOGPKEY] = logp
                    prop[LOGLKEY] = sampler.log_likelihood(prop)
                    alpha = np.exp(
                        log_factor
                        + sampler.beta * prop[LOGLKEY]
                        + prop[LOGPKEY]
                        - sampler.beta * curr[LOGLKEY]
                        - curr[LOGPKEY]
                    )

                    if random.rng.uniform(0, 1) <= alpha:
                        sampler.accept_proposal(prop, proposal)
                    else:
                        sampler.reject_proposal(curr, proposal)
                    self.sampler_dictionary[Tindex][Eindex] = sampler

    def adapt_temperatures(self):
        """Adapt the temperature of the chains

        Using the dynamic temperature selection described in arXiv:1501.05823,
        adapt the chains to target a constant swap ratio. This method is based
        on github.com/willvousden/ptemcee/tree/master/ptemcee
        """

        self.primary_sampler.chain.minimum_index_adapt = self.position
        tt = self.swap_counter["temperature"]
        for sampler_list in self.sampler_list_of_tempered_lists:
            betas = np.array([s.beta for s in sampler_list])
            ratios = np.array([s.acceptance_ratio for s in sampler_list[:-1]])

            # Modulate temperature adjustments with a hyperbolic decay.
            decay = self.adapt_t0 / (tt + self.adapt_t0)
            kappa = decay / self.adapt_nu

            # Construct temperature adjustments.
            dSs = kappa * (ratios[:-1] - ratios[1:])

            # Compute new ladder (hottest and coldest chains don't move).
            deltaTs = np.diff(1 / betas[:-1])
            deltaTs *= np.exp(dSs)
            betas[1:-1] = 1 / (np.cumsum(deltaTs) + 1 / betas[0])
            for sampler, beta in zip(sampler_list, betas):
                sampler.beta = beta
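
    # In the notation of arXiv:1501.05823, the update above implements
    #
    #     kappa(t) = (1 / nu) * t0 / (t + t0)
    #     S_i -> S_i + kappa(t) * (A_i - A_{i+1})
    #
    # where S_i = log(T_{i+1} - T_i) and A_i is the acceptance ratio supplied
    # for rung i. The hyperbolic decay of kappa(t) freezes the ladder as the
    # number of swap rounds t grows; the endpoints of the ladder are held
    # fixed throughout.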

    @property
    def ln_z(self):
        return self.ln_z_dict.get(self.evidence_method, np.nan)

    @property
    def ln_z_err(self):
        return self.ln_z_err_dict.get(self.evidence_method, np.nan)

    def compute_evidence(self, outdir, label, make_plots=True):
        if self.ntemps == 1:
            return
        kwargs = dict(outdir=outdir, label=label, make_plots=make_plots)
        methods = dict(
            thermodynamic=self.thermodynamic_integration_evidence,
            stepping_stone=self.stepping_stone_evidence,
        )
        for key, method in methods.items():
            ln_z, ln_z_err = self.compute_evidence_per_ensemble(method, kwargs)
            self.ln_z_dict[key] = ln_z
            self.ln_z_err_dict[key] = ln_z_err
            logger.debug(
                f"Log-evidence of {ln_z:0.2f}+/-{ln_z_err:0.2f} calculated using {key} method"
            )

    def compute_evidence_per_ensemble(self, method, kwargs):
        from scipy.special import logsumexp

        if self.ntemps == 1:
            return np.nan, np.nan

        lnZ_list = []
        lnZerr_list = []
        for index, ptchain in enumerate(self.sampler_list_of_tempered_lists):
            lnZ, lnZerr = method(ptchain, **kwargs)
            lnZ_list.append(lnZ)
            lnZerr_list.append(lnZerr)

        N = len(lnZ_list)

        # Average lnZ
        lnZ = logsumexp(lnZ_list, b=1.0 / N)

        # Propagate uncertainty in combined evidence
        lnZerr = 0.5 * logsumexp(2 * np.array(lnZerr_list), b=1.0 / N)

        return lnZ, lnZerr

    def thermodynamic_integration_evidence(
        self, ptchain, outdir, label, make_plots=True
    ):
        """Computes the evidence using thermodynamic integration

        We compute the evidence without the burn-in samples and without thinning
        """
        from scipy.stats import sem

        betas = []
        mean_lnlikes = []
        sem_lnlikes = []
        for sampler in ptchain:
            lnlikes = sampler.chain.get_1d_array(LOGLKEY)
            mindex = sampler.chain.minimum_index
            lnlikes = lnlikes[mindex:]
            mean_lnlikes.append(np.mean(lnlikes))
            sem_lnlikes.append(sem(lnlikes))
            betas.append(sampler.beta)

        # Convert to array and re-order
        betas = np.array(betas)[::-1]
        mean_lnlikes = np.array(mean_lnlikes)[::-1]
        sem_lnlikes = np.array(sem_lnlikes)[::-1]

        lnZ, lnZerr = self._compute_evidence_from_mean_lnlikes(betas, mean_lnlikes)

        if make_plots:
            plot_label = f"{label}_E{ptchain[0].Eindex}"
            self._create_lnZ_plots(
                betas=betas,
                mean_lnlikes=mean_lnlikes,
                outdir=outdir,
                label=plot_label,
                sem_lnlikes=sem_lnlikes,
            )

        return lnZ, lnZerr

    def stepping_stone_evidence(self, ptchain, outdir, label, make_plots=True):
        """
        Compute the evidence using the stepping stone approximation.

        See https://arxiv.org/abs/1810.04488 and
        https://pubmed.ncbi.nlm.nih.gov/21187451/ for details.

        The uncertainty is estimated by a block-bootstrap over the per-step
        contributions to the evidence.

        Returns
        -------
        ln_z: float
            Estimate of the natural log evidence
        ln_z_err: float
            Estimate of the uncertainty in the evidence
        """
        # Order in increasing beta
        ptchain.reverse()

        # Get maximum usable set of samples across the ptchain
        min_index = max([samp.chain.minimum_index for samp in ptchain])
        max_index = min([len(samp.chain.get_1d_array(LOGLKEY)) for samp in ptchain])
        tau = self.tau

        if max_index - min_index <= 1 or np.isinf(tau):
            return np.nan, np.nan

        # Read in log likelihoods
        ln_likes = np.array(
            [samp.chain.get_1d_array(LOGLKEY)[min_index:max_index] for samp in ptchain]
        )[:-1].T

        # Thin to only independent samples
        ln_likes = ln_likes[:: int(self.tau), :]
        steps = ln_likes.shape[0]

        # Calculate delta betas
        betas = np.array([samp.beta for samp in ptchain])

        ln_z, ln_ratio = self._calculate_stepping_stone(betas, ln_likes)

        # Implementation of the block-bootstrap method described in
        # Maturana-Russel et al. (2019) to estimate the evidence uncertainty.
        ll = 50  # Block length
        repeats = 100  # Repeats
        ln_z_realisations = []
        try:
            for _ in range(repeats):
                idxs = [random.rng.integers(i, i + ll) for i in range(steps - ll)]
                ln_z_realisations.append(
                    self._calculate_stepping_stone(betas, ln_likes[idxs, :])[0]
                )
            ln_z_err = np.std(ln_z_realisations)
        except ValueError:
            logger.info("Failed to estimate stepping stone uncertainty")
            ln_z_err = np.nan

        if make_plots:
            plot_label = f"{label}_E{ptchain[0].Eindex}"
            self._create_stepping_stone_plot(
                means=ln_ratio,
                outdir=outdir,
                label=plot_label,
            )

        return ln_z, ln_z_err

    @staticmethod
    def _calculate_stepping_stone(betas, ln_likes):
        from scipy.special import logsumexp

        n_samples = ln_likes.shape[0]
        d_betas = betas[1:] - betas[:-1]
        ln_ratio = logsumexp(d_betas * ln_likes, axis=0) - np.log(n_samples)
        return sum(ln_ratio), ln_ratio
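
    # The identity implemented above: each adjacent pair of rungs contributes
    # an importance-sampling estimate of a ratio of normalising constants,
    #
    #     ln r_k = ln[ (1/n) * sum_i exp((beta_{k+1} - beta_k) * lnL_{k,i}) ]
    #
    # and ln Z = sum_k ln r_k telescopes these ratios from the hottest rung
    # up to the beta=1 posterior.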

    @staticmethod
    def _compute_evidence_from_mean_lnlikes(betas, mean_lnlikes):
        lnZ = np.trapz(mean_lnlikes, betas)
        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1], betas[::-1][::2][::-1])
        lnZerr = np.abs(lnZ - z2)
        return lnZ, lnZerr
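
    # Thermodynamic integration evaluates
    #
    #     ln Z = int_0^1 <ln L>_beta d(beta)
    #
    # by trapezoidal quadrature over the ladder; the error estimate above
    # compares the full-resolution integral with one computed from every
    # other rung.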

    def _create_lnZ_plots(self, betas, mean_lnlikes, outdir, label, sem_lnlikes=None):
        import matplotlib.pyplot as plt

        logger.debug("Creating thermodynamic evidence diagnostic plot")

        fig, ax1 = plt.subplots()
        if betas[-1] == 0:
            x, y = betas[:-1], mean_lnlikes[:-1]
        else:
            x, y = betas, mean_lnlikes
        if sem_lnlikes is not None:
            ax1.errorbar(x, y, sem_lnlikes, fmt="-")
        else:
            ax1.plot(x, y, "-o")
        ax1.set_xscale("log")
        ax1.set_xlabel(r"$\beta$")
        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")

        plt.tight_layout()
        fig.savefig("{}/{}_beta_lnl.png".format(outdir, label))
        plt.close()

    def _create_stepping_stone_plot(self, means, outdir, label):
        import matplotlib.pyplot as plt

        logger.debug("Creating stepping stone evidence diagnostic plot")

        n_steps = len(means)

        fig, axes = plt.subplots(nrows=2, figsize=(8, 10))

        ax = axes[0]
        ax.plot(np.arange(1, n_steps + 1), means)
        ax.set_xlabel("$k$")
        ax.set_ylabel("$r_{k}$")

        ax = axes[1]
        ax.plot(np.arange(1, n_steps + 1), np.cumsum(means[::1])[::1])
        ax.set_xlabel("$k$")
        ax.set_ylabel("Cumulative $\\ln Z$")

        plt.tight_layout()
        fig.savefig("{}/{}_stepping_stone.png".format(outdir, label))
        plt.close()

    @property
    def rejection_sampling_count(self):
        if self.pt_rejection_sample:
            counts = 0
            for column in self.sampler_list_of_tempered_lists:
                for sampler in column:
                    counts += sampler.rejection_sampling_count
            return counts
        else:
            return None


class BilbyMCMCSampler(object):
    def __init__(
        self,
        convergence_inputs,
        proposal_cycle=None,
        beta=1,
        Tindex=0,
        Eindex=0,
        use_ratio=False,
        initial_sample_method="prior",
        initial_sample_dict=None,
        normalize_prior=True,
    ):
        self.beta = beta
        self.Tindex = Tindex
        self.Eindex = Eindex
        self.use_ratio = use_ratio
        self.normalize_prior = normalize_prior
        self.parameters = _sampling_convenience_dump.priors.non_fixed_keys
        self.ndim = len(self.parameters)

        if initial_sample_method.lower() == "prior":
            full_sample_dict = _sampling_convenience_dump.priors.sample()
            initial_sample = {
                k: v
                for k, v in full_sample_dict.items()
                if k in _sampling_convenience_dump.priors.non_fixed_keys
            }
        elif initial_sample_method.lower() in ["maximize", "maximise", "maximum"]:
            initial_sample = get_initial_maximimum_posterior_sample(self.beta)
        else:
            raise ValueError(
                f"initial sample method {initial_sample_method} not understood"
            )

        if initial_sample_dict is not None:
            initial_sample.update(initial_sample_dict)

        if self.beta == 1:
            logger.info(f"Using initial sample {initial_sample}")

        initial_sample = Sample(initial_sample)
        initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
        initial_sample[LOGPKEY] = self.log_prior(initial_sample)

        self.chain = Chain(initial_sample=initial_sample)
        self.set_convergence_inputs(convergence_inputs)

        self.accepted = 0
        self.rejected = 0
        self.pt_accepted = 0
        self.pt_rejected = 0
        self.rejection_sampling_count = 0

        if isinstance(proposal_cycle, str):
            # Only print warnings for the primary sampler
            if Tindex == 0 and Eindex == 0:
                warn = True
            else:
                warn = False

            self.proposal_cycle = proposals.get_proposal_cycle(
                proposal_cycle,
                _sampling_convenience_dump.priors,
                L1steps=self.chain.L1steps,
                warn=warn,
            )
        elif isinstance(proposal_cycle, proposals.ProposalCycle):
            self.proposal_cycle = proposal_cycle
        else:
            raise SamplerError("Proposal cycle not understood")

        if self.Tindex == 0 and self.Eindex == 0:
            logger.info(f"Using {self.proposal_cycle}")

    def set_convergence_inputs(self, convergence_inputs):
        for key, val in convergence_inputs._asdict().items():
            setattr(self.chain, key, val)
        self.target_nsamples = convergence_inputs.target_nsamples
        self.stop_after_convergence = convergence_inputs.stop_after_convergence

    def log_likelihood(self, sample):
        _sampling_convenience_dump.likelihood.parameters.update(sample.sample_dict)

        if self.use_ratio:
            logl = _sampling_convenience_dump.likelihood.log_likelihood_ratio()
        else:
            logl = _sampling_convenience_dump.likelihood.log_likelihood()

        return logl

    def log_prior(self, sample):
        return _sampling_convenience_dump.priors.ln_prob(
            sample.parameter_only_dict,
            normalized=self.normalize_prior,
        )

    def accept_proposal(self, prop, proposal):
        self.chain.append(prop)
        self.accepted += 1
        proposal.accepted += 1

    def reject_proposal(self, curr, proposal):
        self.chain.append(curr)
        self.rejected += 1
        proposal.rejected += 1

    def step(self):
        if self.stop_after_convergence and self.chain.converged:
            return self

        internal_steps = 0
        internal_accepted = 0
        internal_rejected = 0
        curr = self.chain.current_sample.copy()
        while internal_steps < self.chain.L1steps:
            internal_steps += 1
            proposal = self.proposal_cycle.get_proposal()
            prop, log_factor = proposal(
                self.chain,
                likelihood=_sampling_convenience_dump.likelihood,
                priors=_sampling_convenience_dump.priors,
            )
            logp = self.log_prior(prop)

            if np.isinf(logp) or np.isnan(logp):
                internal_rejected += 1
                proposal.rejected += 1
                continue

            prop[LOGPKEY] = logp
            prop[LOGLKEY] = self.log_likelihood(prop)

            if np.isinf(prop[LOGLKEY]) or np.isnan(prop[LOGLKEY]):
                internal_rejected += 1
                proposal.rejected += 1
                continue

            with np.errstate(over="ignore"):
                alpha = np.exp(
                    log_factor
                    + self.beta * prop[LOGLKEY]
                    + prop[LOGPKEY]
                    - self.beta * curr[LOGLKEY]
                    - curr[LOGPKEY]
                )

            if random.rng.uniform(0, 1) <= alpha:
                internal_accepted += 1
                proposal.accepted += 1
                curr = prop
                self.chain.current_sample = curr
            else:
                internal_rejected += 1
                proposal.rejected += 1

        self.chain.append(curr)
        self.rejected += internal_rejected
        self.accepted += internal_accepted
        return self

    @property
    def nsamples(self):
        nsamples = self.chain.nsamples
        if nsamples > self.target_nsamples and self.chain.converged is False:
            logger.debug(f"Temperature {self.Tindex} chain reached convergence")
            self.chain.converged = True
        return nsamples

    @property
    def acceptance_ratio(self):
        return self.accepted / (self.accepted + self.rejected)

    @property
    def samples(self):
        if self.beta == 1:
            return self.chain.samples
        else:
            return self.rejection_sample_zero_temperature_samples(print_message=True)

    def rejection_sample_zero_temperature_samples(self, print_message=False):
        beta = self.beta
        chain = self.chain
        hot_samples = pd.DataFrame(
            chain._chain_array[chain.minimum_index : chain.position], columns=chain.keys
        )
        if len(hot_samples) == 0:
            logger.debug(
                f"Rejection sampling for Temp {self.Tindex} failed: "
                "no usable hot samples"
            )
            return hot_samples

        # Pull out log likelihood
        zerotemp_logl = hot_samples[LOGLKEY]

        # Revert to true likelihood if needed
        if _sampling_convenience_dump.use_ratio:
            zerotemp_logl += (
                _sampling_convenience_dump.likelihood.noise_log_likelihood()
            )

        # Calculate normalised weights
        log_weights = (1 - beta) * zerotemp_logl
        max_weight = np.max(log_weights)
        unnormalised_weights = np.exp(log_weights - max_weight)
        weights = unnormalised_weights / np.sum(unnormalised_weights)

        # Rejection sample
        samples = rejection_sample(hot_samples, weights)

        # Logging
        self.rejection_sampling_count = len(samples)

        if print_message:
            logger.info(
                f"Rejection sampling Temp {self.Tindex}, beta={beta:0.2f} "
                f"yielded {len(samples)} samples"
            )
        return samples
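

# A note on the tempered-chain rejection sampling above: a chain at inverse
# temperature beta targets L(theta)**beta * pi(theta), so re-weighting its
# samples to the beta=1 posterior uses weights
#
#     w_i proportional to exp[(1 - beta) * lnL_i]
#
# the overall normalisation of which is irrelevant to the accept/reject draw.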


def get_initial_maximimum_posterior_sample(beta):
    """A method to attempt optimization of the maximum posterior estimate

    This uses a simple scipy optimization approach, starting from a number
    of draws from the prior to avoid problems with local optimization.

    """
    logger.info("Finding initial maximum posterior estimate")
    likelihood = _sampling_convenience_dump.likelihood
    priors = _sampling_convenience_dump.priors
    search_parameter_keys = _sampling_convenience_dump.search_parameter_keys

    bounds = []
    for key in search_parameter_keys:
        bounds.append((priors[key].minimum, priors[key].maximum))

    def neg_log_post(x):
        sample = {key: val for key, val in zip(search_parameter_keys, x)}
        ln_prior = priors.ln_prob(sample)

        if np.isinf(ln_prior):
            # Forbidden by the prior: return +inf so the minimizer avoids
            # this point rather than being attracted to it
            return np.inf

        likelihood.parameters.update(sample)

        return -beta * likelihood.log_likelihood() - ln_prior

    res = differential_evolution(neg_log_post, bounds, popsize=100, init="sobol")
    if res.success:
        sample = {key: val for key, val in zip(search_parameter_keys, res.x)}
        logger.info(f"Initial maximum posterior estimate {sample}")
        return sample
    else:
        raise ValueError("Failed to find initial maximum posterior estimate")


# Methods used to aid parallelisation:


def call_step(sampler):
    # A picklable, module-level helper so that a multiprocessing pool can map
    # ``step`` over the sampler list and return the updated samplers.
    sampler = sampler.step()
    return sampler
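

# An illustrative sketch (not part of the module) of how ``call_step`` is used
# by ``BilbyPTMCMCSampler.step_all_chains``: the pool maps it over the chains
# and the stepped samplers are collected back, e.g.
#
#     from multiprocessing import Pool
#
#     with Pool(npool) as pool:
#         sampler_list = pool.map(call_step, sampler_list)
#
# where ``npool`` and ``sampler_list`` stand in for the attributes managed by
# the sampler classes above.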