Coverage for bilby/core/sampler/zeus.py: 37%
81 statements
coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
import os
import shutil
from shutil import copyfile

import numpy as np

from .base_sampler import SamplerError, signal_wrapper
from .emcee import Emcee
from .ptemcee import LikePriorEvaluator

_evaluator = LikePriorEvaluator()
class Zeus(Emcee):
    """bilby wrapper for Zeus (https://zeus-mcmc.readthedocs.io/)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `zeus.EnsembleSampler`; see the
    documentation for that class for further help. Below, we list commonly
    used kwargs and the bilby defaults.

    Parameters
    ==========
    nwalkers: int, (500)
        The number of walkers
    nsteps: int, (100)
        The number of steps
    nburn: int (None)
        If given, the fixed number of steps to discard as burn-in. These will
        be discarded from the total number of steps set by `nsteps` and
        therefore the value must be less than `nsteps`. Else, nburn is
        estimated from the autocorrelation time
    burn_in_fraction: float, (0.25)
        The fraction of steps to discard as burn-in in the event that the
        autocorrelation time cannot be calculated
    burn_in_act: float
        The number of autocorrelation times to discard as burn-in

    """
    sampler_name = "zeus"
    default_kwargs = dict(
        nwalkers=500,
        args=[],
        kwargs={},
        pool=None,
        log_prob0=None,
        start=None,
        blobs0=None,
        iterations=100,
        thin=1,
    )
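    # Of these defaults, `nwalkers`, `args`, `kwargs` and `pool` are intended for
    # the `zeus.EnsembleSampler` constructor, while `log_prob0`, `start`, `blobs0`,
    # `iterations` (playing the role of bilby's `nsteps`) and `thin` are forwarded
    # to `sample()`; see `sampler_init_kwargs` and `sampler_function_kwargs` below.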
    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        pos0=None,
        nburn=None,
        burn_in_fraction=0.25,
        resume=True,
        burn_in_act=3,
        **kwargs,
    ):
        super(Zeus, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            pos0=pos0,
            nburn=nburn,
            burn_in_fraction=burn_in_fraction,
            resume=resume,
            burn_in_act=burn_in_act,
            **kwargs,
        )
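    # The emcee-style keywords `rstate0` and `lnprob0` are accepted for
    # compatibility and mapped onto the `start` and `log_prob0` arguments
    # that zeus expects.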
    def _translate_kwargs(self, kwargs):
        super(Zeus, self)._translate_kwargs(kwargs=kwargs)

        # check if using emcee-style arguments
        if "start" not in kwargs:
            if "rstate0" in kwargs:
                kwargs["start"] = kwargs.pop("rstate0")
        if "log_prob0" not in kwargs:
            if "lnprob0" in kwargs:
                kwargs["log_prob0"] = kwargs.pop("lnprob0")
    @property
    def sampler_function_kwargs(self):
        keys = ["log_prob0", "start", "blobs0", "iterations", "thin", "progress"]

        function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}

        return function_kwargs
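    # Any kwarg not consumed by `sample()` above is treated as a constructor
    # argument; the log-probability callable and the dimensionality are then
    # injected explicitly.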
    @property
    def sampler_init_kwargs(self):
        init_kwargs = {
            key: value
            for key, value in self.kwargs.items()
            if key not in self.sampler_function_kwargs
        }

        init_kwargs["logprob_fn"] = _evaluator.call_emcee
        init_kwargs["ndim"] = self.ndim

        return init_kwargs
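    # Checkpointing (handled by the parent Emcee class) serialises the sampler,
    # so `distribute` is pointed at the builtin `map` while the state is written
    # and then restored to the pool's `map` if a pool is attached.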
    def write_current_state(self):
        self._sampler.distribute = map
        super(Zeus, self).write_current_state()
        self._sampler.distribute = getattr(self._sampler.pool, "map", map)
    def _initialise_sampler(self):
        from zeus import EnsembleSampler

        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
        self._init_chain_file()
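    # For each `sample` yielded by zeus, the walker positions (`sample[0]`) and
    # the blobs carrying log-likelihood and log-prior (`sample[2]`) are appended
    # to the on-disk chain file via a temporary copy.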
    def write_chains_to_file(self, sample):
        chain_file = self.checkpoint_info.chain_file
        temp_chain_file = chain_file + ".temp"
        if os.path.isfile(chain_file):
            copyfile(chain_file, temp_chain_file)

        points = np.hstack([sample[0], np.array(sample[2])])

        with open(temp_chain_file, "a") as ff:
            for ii, point in enumerate(points):
                ff.write(self.checkpoint_info.chain_template.format(ii, *point))
        shutil.move(temp_chain_file, chain_file)
    def _set_pos0_for_resume(self):
        self.pos0 = self.sampler.get_last_sample()
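    # When resuming, the requested number of iterations is reduced by those
    # already stored in the chain file, and sampling restarts from `pos0`
    # (the last recorded ensemble position).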
    @signal_wrapper
    def run_sampler(self):
        self._setup_pool()
        sampler_function_kwargs = self.sampler_function_kwargs
        iterations = sampler_function_kwargs.pop("iterations")
        iterations -= self._previous_iterations

        sampler_function_kwargs["start"] = self.pos0

        # main iteration loop
        for sample in self.sampler.sample(
            iterations=iterations, **sampler_function_kwargs
        ):
            self.write_chains_to_file(sample)
        self._close_pool()
        self.write_current_state()

        self.result.sampler_output = np.nan
        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
        self.print_nburn_logging_info()

        self._generate_result()

        self.result.samples = self.sampler.get_chain(flat=True, discard=self.nburn)
        self.result.walkers = self.sampler.chain
        return self.result
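    # Post-processing: the flattened blobs are unpacked into log-likelihood and
    # log-prior arrays, while the evidence is left as NaN since zeus, being an
    # MCMC sampler, does not estimate it.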
    def _generate_result(self):
        self.result.nburn = self.nburn
        self.calc_likelihood_count()
        if self.result.nburn > self.nsteps:
            raise SamplerError(
                "The run has finished, but the chain is not burned in: "
                "the requirement `nburn < nsteps` is not satisfied "
                f"({self.result.nburn} > {self.nsteps})."
                " Try increasing the number of steps."
            )
        blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape(
            (-1, 2)
        )
        log_likelihoods, log_priors = blobs.T
        self.result.log_likelihood_evaluations = log_likelihoods
        self.result.log_prior_evaluations = log_priors
        self.result.log_evidence = np.nan
        self.result.log_evidence_err = np.nan