Coverage for bilby/core/sampler/emcee.py: 35%
220 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import os
2import shutil
3from collections import namedtuple
5import numpy as np
6from packaging import version
7from pandas import DataFrame
9from ..utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump
10from .base_sampler import MCMCSampler, SamplerError, signal_wrapper
11from .ptemcee import LikePriorEvaluator
# Module-level likelihood/prior evaluator shared with the ptemcee
# implementation; kept at module scope so it remains picklable when the
# sampler is checkpointed with dill. Its `call_emcee` method is handed to
# emcee as the log-posterior callable (see `sampler_init_kwargs` below).
_evaluator = LikePriorEvaluator()
class Emcee(MCMCSampler):
    """bilby wrapper emcee (https://github.com/dfm/emcee)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `emcee.EnsembleSampler`, see
    documentation for that class for further help. Under Other Parameters, we
    list commonly used kwargs and the bilby defaults.

    Parameters
    ==========
    nwalkers: int, (500)
        The number of walkers
    nsteps: int, (100)
        The number of steps
    nburn: int (None)
        If given, the fixed number of steps to discard as burn-in. These will
        be discarded from the total number of steps set by `nsteps` and
        therefore the value must be less than `nsteps`. Else, nburn is
        estimated from the autocorrelation time
    burn_in_fraction: float, (0.25)
        The fraction of steps to discard as burn-in in the event that the
        autocorrelation time cannot be calculated
    burn_in_act: float
        The number of autocorrelation times to discard as burn-in
    a: float (2)
        The proposal scale factor
    verbose: bool
        Whether to print diagnostic information during the analysis

    """

    sampler_name = "emcee"
    # Defaults forwarded to emcee; the `sampler_init_kwargs` and
    # `sampler_function_kwargs` properties split these between the
    # EnsembleSampler constructor and its `sample()` call, renaming keys
    # for emcee >= 3 where the API changed.
    default_kwargs = dict(
        nwalkers=500,
        a=2,
        args=[],
        kwargs={},
        postargs=None,
        pool=None,
        live_dangerously=False,
        runtime_sortingfn=None,
        lnprob0=None,
        rstate0=None,
        blobs0=None,
        iterations=100,
        thin=1,
        storechain=True,
        mh_proposal=None,
    )
    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        pos0=None,
        nburn=None,
        burn_in_fraction=0.25,
        resume=True,
        burn_in_act=3,
        **kwargs,
    ):
        # Detect the installed emcee API version first: kwarg handling in the
        # properties below branches on `self.prerelease`.
        self._check_version()
        super(Emcee, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            **kwargs,
        )
        # Whether to resume from an existing checkpoint pickle (see `sampler`).
        self.resume = resume
        # Optional user-supplied initial walker positions; validated later by
        # `_set_pos0`.
        self.pos0 = pos0
        # Fixed burn-in; the property setter validates it against `iterations`.
        self.nburn = nburn
        self.burn_in_fraction = burn_in_fraction
        self.burn_in_act = burn_in_act
        # NOTE(review): "verbose" is read but not popped from kwargs here;
        # presumably the base class filters unknown keys against
        # `default_kwargs` before they reach emcee — confirm upstream.
        self.verbose = kwargs.get("verbose", True)
100 def _check_version(self):
101 import emcee
103 if version.parse(emcee.__version__) < version.parse("3"):
104 self.prerelease = False
105 else:
106 self.prerelease = True
108 def _translate_kwargs(self, kwargs):
109 kwargs = super()._translate_kwargs(kwargs)
110 if "nwalkers" not in kwargs:
111 for equiv in self.nwalkers_equiv_kwargs:
112 if equiv in kwargs:
113 kwargs["nwalkers"] = kwargs.pop(equiv)
114 if "iterations" not in kwargs:
115 if "nsteps" in kwargs:
116 kwargs["iterations"] = kwargs.pop("nsteps")
118 @property
119 def sampler_function_kwargs(self):
120 keys = [
121 "lnprob0",
122 "rstate0",
123 "blobs0",
124 "iterations",
125 "thin",
126 "storechain",
127 "mh_proposal",
128 ]
130 # updated function keywords for emcee > v2.2.1
131 updatekeys = {
132 "p0": "initial_state",
133 "lnprob0": "log_prob0",
134 "storechain": "store",
135 }
137 function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
138 function_kwargs["p0"] = self.pos0
140 if self.prerelease:
141 if function_kwargs["mh_proposal"] is not None:
142 logger.warning(
143 "The 'mh_proposal' option is no longer used "
144 "in emcee > 2.2.1, and will be ignored."
145 )
146 del function_kwargs["mh_proposal"]
148 for key in updatekeys:
149 if updatekeys[key] not in function_kwargs:
150 function_kwargs[updatekeys[key]] = function_kwargs.pop(key)
151 else:
152 del function_kwargs[key]
154 return function_kwargs
156 @property
157 def sampler_init_kwargs(self):
158 init_kwargs = {
159 key: value
160 for key, value in self.kwargs.items()
161 if key not in self.sampler_function_kwargs
162 }
164 init_kwargs["lnpostfn"] = _evaluator.call_emcee
165 init_kwargs["dim"] = self.ndim
167 # updated init keywords for emcee > v2.2.1
168 updatekeys = {"dim": "ndim", "lnpostfn": "log_prob_fn"}
170 if self.prerelease:
171 for key in updatekeys:
172 if key in init_kwargs:
173 init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
175 oldfunckeys = ["p0", "lnprob0", "storechain", "mh_proposal"]
176 for key in oldfunckeys:
177 if key in init_kwargs:
178 del init_kwargs[key]
180 return init_kwargs
    @property
    def nburn(self):
        """int: Number of steps discarded as burn-in.

        Resolution order: an explicitly set numeric value; otherwise an
        estimate of ``burn_in_act`` autocorrelation times; otherwise
        ``burn_in_fraction`` of the total steps when no autocorrelation
        time is available.
        """
        # `type(...) in [...]` (rather than isinstance) returns False for
        # None, so an unset nburn falls through to the estimates below.
        if type(self.__nburn) in [float, int]:
            return int(self.__nburn)
        elif self.result.max_autocorrelation_time is None:
            return int(self.burn_in_fraction * self.nsteps)
        else:
            return int(self.burn_in_act * self.result.max_autocorrelation_time)

    @nburn.setter
    def nburn(self, nburn):
        # Only validate when a numeric burn-in was given; None means
        # "estimate at read time" (see the getter).
        if isinstance(nburn, (float, int)):
            if nburn > self.kwargs["iterations"] - 1:
                raise ValueError(
                    "Number of burn-in samples must be smaller "
                    "than the total number of iterations"
                )

        self.__nburn = nburn
    @property
    def nwalkers(self):
        """int: Number of ensemble walkers (the ``nwalkers`` kwarg)."""
        return self.kwargs["nwalkers"]

    @property
    def nsteps(self):
        """int: Number of sampling steps (stored as the ``iterations`` kwarg)."""
        return self.kwargs["iterations"]

    @nsteps.setter
    def nsteps(self, nsteps):
        self.kwargs["iterations"] = nsteps
    @property
    def stored_chain(self):
        """Read the stored zero-temperature chain data in from disk"""
        # names=True picks up the tab-separated header written by
        # `_init_chain_file`, so the structured array is indexable by
        # parameter name (plus "walker", "log_l" and "log_p").
        return np.genfromtxt(self.checkpoint_info.chain_file, names=True)

    @property
    def stored_samples(self):
        """Returns the samples stored on disk"""
        return self.stored_chain[self.search_parameter_keys]

    @property
    def stored_loglike(self):
        """Returns the log-likelihood stored on disk"""
        return self.stored_chain["log_l"]

    @property
    def stored_logprior(self):
        """Returns the log-prior stored on disk"""
        return self.stored_chain["log_p"]
234 def _init_chain_file(self):
235 with open(self.checkpoint_info.chain_file, "w+") as ff:
236 search_keys_str = "\t".join(self.search_parameter_keys)
237 ff.write(f"walker\t{search_keys_str}\tlog_l\tlog_p\n")
239 @property
240 def checkpoint_info(self):
241 """Defines various things related to checkpointing and storing data
243 Returns
244 =======
245 checkpoint_info: named_tuple
246 An object with attributes `sampler_file`, `chain_file`, and
247 `chain_template`. The first two give paths to where the sampler and
248 chain data is stored, the last a formatted-str-template with which
249 to write the chain data to disk
251 """
252 out_dir = os.path.join(
253 self.outdir, f"{self.__class__.__name__.lower()}_{self.label}"
254 )
255 check_directory_exists_and_if_not_mkdir(out_dir)
257 chain_file = os.path.join(out_dir, "chain.dat")
258 sampler_file = os.path.join(out_dir, "sampler.pickle")
259 chain_template = (
260 "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n"
261 )
263 CheckpointInfo = namedtuple(
264 "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]
265 )
267 checkpoint_info = CheckpointInfo(
268 sampler_file=sampler_file,
269 chain_file=chain_file,
270 chain_template=chain_template,
271 )
273 return checkpoint_info
    @property
    def sampler_chain(self):
        # Truncate the stored chain to the iterations actually completed, so
        # a checkpoint taken mid-run does not contain unfilled steps.
        nsteps = self._previous_iterations
        return self.sampler.chain[:, :nsteps, :]
    def write_current_state(self):
        """
        Writes a pickle file of the sampler to disk using dill

        Overwrites the stored sampler chain with one that is truncated
        to only the completed steps
        """
        logger.info(
            f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}"
        )
        # Truncate the chain to completed iterations before pickling.
        self.sampler._chain = self.sampler_chain
        # Pools are not picklable: detach before dumping, reattach after so
        # the live sampler keeps working.
        _pool = self.sampler.pool
        self.sampler.pool = None
        safe_file_dump(self._sampler, self.checkpoint_info.sampler_file, "dill")
        self.sampler.pool = _pool
    def _initialise_sampler(self):
        """Construct a fresh EnsembleSampler and create the on-disk chain file."""
        from emcee import EnsembleSampler

        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
        self._init_chain_file()
    @property
    def sampler(self):
        """Returns the emcee sampler object

        If already initialized, returns the stored _sampler value. Otherwise,
        first checks if there is a pickle file from which to load. If there is
        not, then initialize the sampler and set the initial random draw

        """
        if hasattr(self, "_sampler"):
            pass
        elif (
            self.resume
            # A zero-size file is treated as "no checkpoint" (e.g. an
            # interrupted first dump).
            and os.path.isfile(self.checkpoint_info.sampler_file)
            and os.path.getsize(self.checkpoint_info.sampler_file)
        ):
            import dill

            logger.info(
                f"Resuming run from checkpoint file {self.checkpoint_info.sampler_file}"
            )
            with open(self.checkpoint_info.sampler_file, "rb") as f:
                self._sampler = dill.load(f)
            # The pool is stripped before pickling (see write_current_state);
            # reattach the live one.
            self._sampler.pool = self.pool
            self._set_pos0_for_resume()
        else:
            self._initialise_sampler()
            self._set_pos0()
        return self._sampler
332 def write_chains_to_file(self, sample):
333 chain_file = self.checkpoint_info.chain_file
334 temp_chain_file = chain_file + ".temp"
335 if self.prerelease:
336 points = np.hstack([sample.coords, sample.blobs])
337 else:
338 points = np.hstack([sample[0], np.array(sample[3])])
339 data_to_write = "\n".join(
340 self.checkpoint_info.chain_template.format(ii, *point)
341 for ii, point in enumerate(points)
342 )
343 with open(temp_chain_file, "w") as ff:
344 ff.write(data_to_write)
345 with open(temp_chain_file, "rb") as ftemp, open(chain_file, "ab") as fchain:
346 shutil.copyfileobj(ftemp, fchain)
347 os.remove(temp_chain_file)
349 @property
350 def _previous_iterations(self):
351 """Returns the number of iterations that the sampler has saved
353 This is used when loading in a sampler from a pickle file to figure out
354 how much of the run has already been completed
355 """
356 try:
357 return len(self.sampler.blobs)
358 except AttributeError:
359 return 0
361 def _draw_pos0_from_prior(self):
362 return np.array(
363 [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
364 )
366 @property
367 def _pos0_shape(self):
368 return (self.nwalkers, self.ndim)
370 def _set_pos0(self):
371 if self.pos0 is not None:
372 logger.debug("Using given initial positions for walkers")
373 if isinstance(self.pos0, DataFrame):
374 self.pos0 = self.pos0[self.search_parameter_keys].values
375 elif type(self.pos0) in (list, np.ndarray):
376 self.pos0 = np.squeeze(self.pos0)
378 if self.pos0.shape != self._pos0_shape:
379 raise ValueError("Input pos0 should be of shape ndim, nwalkers")
380 logger.debug("Checking input pos0")
381 for draw in self.pos0:
382 self.check_draw(draw)
383 else:
384 logger.debug("Generating initial walker positions from prior")
385 self.pos0 = self._draw_pos0_from_prior()
    def _set_pos0_for_resume(self):
        # Resume from the last recorded position of every walker.
        self.pos0 = self.sampler.chain[:, -1, :]
    @signal_wrapper
    def run_sampler(self):
        """Run the sampler, checkpointing every iteration, and build the result."""
        self._setup_pool()
        from tqdm.auto import tqdm

        sampler_function_kwargs = self.sampler_function_kwargs
        iterations = sampler_function_kwargs.pop("iterations")
        # Only run what a previously checkpointed run did not complete.
        iterations -= self._previous_iterations

        # The initial-state keyword was renamed in emcee >= 3.
        if self.prerelease:
            sampler_function_kwargs["initial_state"] = self.pos0
        else:
            sampler_function_kwargs["p0"] = self.pos0

        # main iteration loop
        iterator = self.sampler.sample(iterations=iterations, **sampler_function_kwargs)
        if self.verbose:
            iterator = tqdm(iterator, total=iterations)
        for sample in iterator:
            # Persist the chain after every iteration so interruption loses
            # at most one step.
            self.write_chains_to_file(sample)
        if self.verbose:
            iterator.close()
        self.write_current_state()
        self._close_pool()

        self.result.sampler_output = np.nan
        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
        self.print_nburn_logging_info()

        self._generate_result()

        # Discard burn-in and flatten walkers into a single set of samples.
        self.result.samples = self.sampler.chain[:, self.nburn :, :].reshape(
            (-1, self.ndim)
        )
        self.result.walkers = self.sampler.chain
        return self.result
    def _generate_result(self):
        """Populate the result with burn-in, likelihood and prior evaluations.

        Raises
        ======
        SamplerError: if the estimated burn-in exceeds the number of steps,
            i.e. the requirement nburn < nsteps was not met.
        """
        self.result.nburn = self.nburn
        self.calc_likelihood_count()
        if self.result.nburn > self.nsteps:
            raise SamplerError(
                "The run has finished, but the chain is not burned in: "
                f"`nburn < nsteps` ({self.result.nburn} < {self.nsteps})."
                " Try increasing the number of steps."
            )
        # Blobs hold (log_likelihood, log_prior) pairs per step per walker;
        # trim the burn-in then flatten to two columns.
        blobs = np.array(self.sampler.blobs)
        blobs_trimmed = blobs[self.nburn :, :, :].reshape((-1, 2))
        log_likelihoods, log_priors = blobs_trimmed.T
        self.result.log_likelihood_evaluations = log_likelihoods
        self.result.log_prior_evaluations = log_priors
        # The ensemble sampler does not estimate the evidence.
        self.result.log_evidence = np.nan
        self.result.log_evidence_err = np.nan