Coverage for bilby/bilby_mcmc/chain.py: 93%
307 statements
coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
import numpy as np
import pandas as pd
from packaging import version

from ..core.sampler.base_sampler import SamplerError
from ..core.utils import logger
from .utils import LOGLKEY, LOGLLATEXKEY, LOGPKEY, LOGPLATEXKEY


class Chain(object):
    def __init__(
        self,
        initial_sample,
        burn_in_nact=1,
        thin_by_nact=1,
        fixed_discard=0,
        autocorr_c=5,
        min_tau=1,
        fixed_tau=None,
        tau_window=None,
        block_length=100000,
    ):
        """Object to store a single MCMC chain

        Parameters
        ----------
        initial_sample: bilby.bilby_mcmc.chain.Sample
            The starting point of the chain
        burn_in_nact, thin_by_nact : int (1, 1)
            The number of autocorrelation times (tau) to discard for burn-in
            and the multiplicative factor to thin by (thin_by_nact < 1 yields
            correlated samples). E.g., burn_in_nact=10 and thin_by_nact=1
            will discard 10*tau samples from the start of the chain, then
            thin the final chain by a factor of 1*tau (resulting in
            independent samples).
        fixed_discard: int (0)
            A fixed minimum number of samples to discard (can be used to
            override burn_in_nact if it is too small).
        autocorr_c: float (5)
            The step size of the window search used by emcee.autocorr when
            estimating the autocorrelation time.
        min_tau: int (1)
            A minimum value for the autocorrelation time.
        fixed_tau: int (None)
            A fixed value for the autocorrelation time (overrides the
            automated estimation). Used in testing.
        tau_window: int (None)
            Only calculate the autocorrelation time in a trailing window of
            tau_window * tau samples. If None (default), this method is not
            used.
        block_length: int
            The number of rows by which to extend the chain array when it
            runs out of space.
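
        Examples
        --------
        A minimal, illustrative sketch (parameter names are arbitrary;
        LOGLKEY and LOGPKEY are the log-likelihood and log-prior keys
        imported from .utils):

        >>> s = Sample({"a": 0.0, "b": 1.0, LOGLKEY: -1.0, LOGPKEY: 0.0})
        >>> chain = Chain(initial_sample=s)
        >>> chain.append(s.copy())
        >>> chain.position
        1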
        """
        self.autocorr_c = autocorr_c
        self.min_tau = min_tau
        self.burn_in_nact = burn_in_nact
        self.thin_by_nact = thin_by_nact
        self.block_length = block_length
        self.fixed_discard = int(fixed_discard)
        self.fixed_tau = fixed_tau
        self.tau_window = tau_window

        self.ndim = initial_sample.ndim
        self.current_sample = initial_sample
        self.keys = self.current_sample.keys
        self.parameter_keys = self.current_sample.parameter_keys

        # Initialize chain
        self._chain_array = self._get_zero_chain_array()
        self._chain_array_length = block_length
        self.position = -1
        self.max_log_likelihood = -np.inf
        self.max_tau_dict = {}
        self.converged = False
        self.cached_tau_count = 0
        self._minimum_index_proposal = 0
        self._minimum_index_adapt = 0
        self._last_minimum_index = (0, 0, "I")
        self.last_full_tau_dict = {key: np.inf for key in self.parameter_keys}

        # Append the initial sample
        self.append(self.current_sample)

    def _get_zero_chain_array(self):
        return np.zeros((self.block_length, self.ndim + 2), dtype=np.float64)

    def _extend_chain_array(self):
        self._chain_array = np.concatenate(
            (self._chain_array, self._get_zero_chain_array()), axis=0
        )
        self._chain_array_length = len(self._chain_array)

    @property
    def current_sample(self):
        return self._current_sample.copy()

    @current_sample.setter
    def current_sample(self, current_sample):
        self._current_sample = current_sample

    def append(self, sample):
        self.position += 1

        # Extend the array if needed
        if self.position >= self._chain_array_length:
            self._extend_chain_array()

        # Store the current sample and append to the array
        self.current_sample = sample
        self._chain_array[self.position] = sample.list

        # Update the maximum log_likelihood
        if sample[LOGLKEY] > self.max_log_likelihood:
            self.max_log_likelihood = sample[LOGLKEY]

    def __getitem__(self, index):
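        # Support negative indexing relative to the most recent sample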
        if index < 0:
            index = index + self.position + 1

        if index <= self.position:
            values = self._chain_array[index]
            return Sample({k: v for k, v in zip(self.keys, values)})
        else:
            raise SamplerError(f"Requested index {index} out of bounds")

    def __setitem__(self, index, sample):
        if index < 0:
            index = index + self.position + 1

        self._chain_array[index] = sample.list

    def key_to_idx(self, key):
        return self.keys.index(key)

    def get_1d_array(self, key):
        return self._chain_array[: 1 + self.position, self.key_to_idx(key)]

    @property
    def _random_idx(self):
        from ..core.utils.random import rng

        mindex = self._last_minimum_index[1]
        # If fewer than 10 ACTs of samples lie beyond the minimum index (or
        # tau is not yet known), draw from the whole chain; otherwise draw
        # only from the chain past the minimum index
        if np.isinf(self.tau_last) or self.position - mindex < 10 * self.tau_last:
            mindex = 0
        return rng.integers(mindex, self.position + 1)

    @property
    def random_sample(self):
        return self[self._random_idx]

    @property
    def fixed_discard(self):
        return self._fixed_discard

    @fixed_discard.setter
    def fixed_discard(self, fixed_discard):
        self._fixed_discard = int(fixed_discard)

    @property
    def minimum_index(self):
        """This calculates a minimum index from which to discard samples

        A number of methods are provided for the calculation. A subset is
        switched off (by `if False` statements) for future development.
        """
        position = self.position

        # Return cached minimum index
        last_minimum_index = self._last_minimum_index
        if position == last_minimum_index[0]:
            return int(last_minimum_index[1])

        # If fixed discard is not yet reached, just return that
        if position < self.fixed_discard:
            self.minimum_index_method = "FD"
            return self.fixed_discard

        # Initialize list of minimum index methods with the fixed discard (FD)
        minimum_index_list = [self.fixed_discard]
        minimum_index_method_list = ["FD"]

        # Calculate minimum index from tau
        if self.tau_last < np.inf:
            tau = self.tau_last
        elif len(self.max_tau_dict) == 0:
            # Bootstrap calculating tau when minimum index has not yet been calculated
            tau = self._tau_for_full_chain
        else:
            tau = np.inf

        if tau < np.inf:
            minimum_index_list.append(self.burn_in_nact * tau)
            minimum_index_method_list.append(f"{self.burn_in_nact}tau")

        # Find the first point at which the rolling mean of the log-posterior
        # comes within delta_lnP of its maximum
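        # (Heuristic, assuming an approximately Gaussian posterior: the mean
        # log-posterior over the typical set sits roughly ndim / 2 below its
        # peak, hence delta_lnP = zfactor * ndim / 2)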
        if True:
            zfactor = 1
            N = 100
            delta_lnP = zfactor * self.ndim / 2
            logl = self.get_1d_array(LOGLKEY)
            log_prior = self.get_1d_array(LOGPKEY)
            log_posterior = logl + log_prior
            max_posterior = np.max(log_posterior)

            ave = pd.Series(log_posterior).rolling(window=N).mean().iloc[N - 1 :]
            delta = max_posterior - ave
            passes = ave[delta < delta_lnP]
            if len(passes) > 0:
                minimum_index_list.append(passes.index[0] + 1)
                minimum_index_method_list.append(f"z{zfactor}")

        # Add last minimum_index_method
        if False:
            minimum_index_list.append(last_minimum_index[1])
            minimum_index_method_list.append(last_minimum_index[2])

        # Minimum index set by proposals
        minimum_index_list.append(self.minimum_index_proposal)
        minimum_index_method_list.append("PR")

        # Minimum index set by temperature adaptation
        minimum_index_list.append(self.minimum_index_adapt)
        minimum_index_method_list.append("AD")

        # Calculate the maximum minimum index and associated method (reporting)
        minimum_index = int(np.max(minimum_index_list))
        minimum_index_method = minimum_index_method_list[np.argmax(minimum_index_list)]

        # Cache the method
        self._last_minimum_index = (position, minimum_index, minimum_index_method)
        self.minimum_index_method = minimum_index_method

        return minimum_index

    @property
    def minimum_index_proposal(self):
        return self._minimum_index_proposal

    @minimum_index_proposal.setter
    def minimum_index_proposal(self, minimum_index_proposal):
        if minimum_index_proposal > self._minimum_index_proposal:
            self._minimum_index_proposal = minimum_index_proposal

    @property
    def minimum_index_adapt(self):
        return self._minimum_index_adapt

    @minimum_index_adapt.setter
    def minimum_index_adapt(self, minimum_index_adapt):
        if minimum_index_adapt > self._minimum_index_adapt:
            self._minimum_index_adapt = minimum_index_adapt

    @property
    def tau(self):
        """The maximum ACT over all parameters"""

        if self.position in self.max_tau_dict:
            # If we have the ACT at the current position, return it
            return self.max_tau_dict[self.position]
        elif (
            self.tau_last < np.inf
            and self.cached_tau_count < 50
            and self.nsamples_last > 50
        ):
            # If we have a recent ACT return it
            self.cached_tau_count += 1
            return self.tau_last
        else:
            # Calculate the ACT
            return self.tau_nocache

    @property
    def tau_nocache(self):
        """Calculate tau forcing a recalculation (no cached tau)"""
        tau = max(self.tau_dict.values())
        self.max_tau_dict[self.position] = tau
        self.cached_tau_count = 0
        return tau

    @property
    def tau_last(self):
        """Return the last-calculated tau if it exists, else inf"""
        if len(self.max_tau_dict) > 0:
            return list(self.max_tau_dict.values())[-1]
        else:
            return np.inf

    @property
    def _tau_for_full_chain(self):
        """The maximum ACT over all parameters, calculated over the full chain"""
        return max(self._tau_dict_for_full_chain.values())

    @property
    def _tau_dict_for_full_chain(self):
        return self._calculate_tau_dict(minimum_index=0)

    @property
    def tau_dict(self):
        """Calculate a dictionary of tau (ACT) for every parameter"""
        return self._calculate_tau_dict(self.minimum_index)

    def _calculate_tau_dict(self, minimum_index):
        """Calculate a dictionary of tau (ACT) for every parameter"""
        logger.debug(f"Calculating tau_dict {self}")

        # If there are too few samples to calculate tau
        if (self.position - minimum_index) < 2 * self.autocorr_c:
            return {key: np.inf for key in self.parameter_keys}

        # Choose the minimum index for the ACT calculation
        last_tau = self.tau_last
        if self.tau_window is not None and last_tau < np.inf:
            minimum_index_for_act = max(
                minimum_index, int(self.position - self.tau_window * last_tau)
            )
        else:
            minimum_index_for_act = minimum_index

        # Calculate a dictionary of taus for each parameter
        taus = {}
        for key in self.parameter_keys:
            if self.fixed_tau is None:
                x = self.get_1d_array(key)[minimum_index_for_act:]
                tau = calculate_tau(x, self.autocorr_c)
                taux = round(tau, 1)
            else:
                taux = self.fixed_tau
            taus[key] = max(taux, self.min_tau)

        # Cache the last tau dictionary for future use
        self.last_full_tau_dict = taus

        return taus

    @property
    def thin(self):
        if np.isfinite(self.tau):
            return np.max([1, int(self.thin_by_nact * self.tau)])
        else:
            return 1

    @property
    def nsamples(self):
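        # Estimate the number of independent samples: the usable post-burn-in
        # steps, reduced by the autocorrelation time and the thinning factor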
        nuseable_steps = self.position - self.minimum_index
        n_independent_samples = nuseable_steps / self.tau
        nsamples = int(n_independent_samples / self.thin_by_nact)
        if nuseable_steps >= nsamples:
            return nsamples
        else:
            return 0

    @property
    def nsamples_last(self):
        nuseable_steps = self.position - self.minimum_index
        return int(nuseable_steps / (self.thin_by_nact * self.tau_last))

    @property
    def samples(self):
        samples = self._chain_array[self.minimum_index : self.position : self.thin]
        return pd.DataFrame(samples, columns=self.keys)

    def plot(self, outdir=".", label="label", priors=None, all_samples=None):
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(
            nrows=self.ndim + 3, ncols=2, figsize=(8, 9 + 3 * (self.ndim))
        )
        scatter_kwargs = dict(
            lw=0,
            marker="o",
        )
        K = 1000

        nburn = self.minimum_index
        plot_setups = zip(
            [0, nburn, nburn],
            [nburn, self.position, self.position],
            [1, 1, self.thin],  # Thin-by factor
            ["tab:red", "tab:grey", "tab:blue"],  # Color
            [0.5, 0.05, 0.5],  # Alpha
            [1, 1, 1],  # Marker size
        )

        position_indexes = np.arange(self.position + 1)

        # Plot the traceplots
        for (start, stop, thin, color, alpha, ms) in plot_setups:
            for ax, key in zip(axes[:, 0], self.keys):
                xx = position_indexes[start:stop:thin] / K
                yy = self.get_1d_array(key)[start:stop:thin]

                # Downsample plots to max_pts: avoid memory issues
                max_pts = 10000
                while len(xx) > max_pts:
                    xx = xx[::2]
                    yy = yy[::2]

                ax.plot(
                    xx,
                    yy,
                    color=color,
                    alpha=alpha,
                    ms=ms,
                    **scatter_kwargs,
                )
                ax.set_ylabel(self._get_plot_label_by_key(key, priors))
                if key not in [LOGLKEY, LOGPKEY]:
                    msg = r"$\tau=$" + f"{self.last_full_tau_dict[key]:0.1f}"
                    ax.set_title(msg)

        # Plot the histograms
        for ax, key in zip(axes[:, 1], self.keys):
            if all_samples is not None:
                yy_all = all_samples[key]
                if np.any(np.isinf(yy_all)):
                    logger.warning(
                        f"Could not plot histogram for parameter {key} due to infinite values"
                    )
                else:
                    ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k")
            yy = self.get_1d_array(key)[nburn : self.position : self.thin]
            if np.any(np.isinf(yy)):
                logger.warning(
                    f"Could not plot histogram for parameter {key} due to infinite values"
                )
            else:
                ax.hist(yy, bins=50, alpha=0.8, density=True)
            ax.set_xlabel(self._get_plot_label_by_key(key, priors))

        # Add the x-axis label to the traceplots
        axes[-1, 0].set_xlabel(r"Iteration $[\times 10^{3}]$")

        # Plot the calculated ACT
        ax = axes[-1, 0]
        tausit = np.array(list(self.max_tau_dict.keys()) + [self.position]) / K
        taus = list(self.max_tau_dict.values()) + [self.tau_last]
        ax.plot(tausit, taus, color="C3")
        ax.set(ylabel=r"Maximum $\tau$")

        axes[-1, 1].set_axis_off()

        filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
        msg = [
            r"Maximum $\tau$" + f"={self.tau:0.1f} ",
            r"$n_{\rm samples}=$" + f"{self.nsamples} ",
        ]
        if self.thin_by_nact != 1:
            msg += [
                r"$n_{\rm samples}^{\rm eff}=$"
                + f"{int(self.nsamples * self.thin_by_nact)} "
            ]
        fig.suptitle(
            "| ".join(msg),
            y=1,
        )
        fig.tight_layout()
        fig.savefig(filename, dpi=200)
        plt.close(fig)

    @staticmethod
    def _get_plot_label_by_key(key, priors=None):
        if priors is not None and key in priors:
            return priors[key].latex_label
        elif key == LOGLKEY:
            return LOGLLATEXKEY
        elif key == LOGPKEY:
            return LOGPLATEXKEY
        else:
            return key


class Sample(object):
    def __init__(self, sample_dict):
        """A single sample

        Parameters
        ----------
        sample_dict: dict
            A dictionary of the sample
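
        Examples
        --------
        An illustrative sample with one parameter (key names are arbitrary):

        >>> s = Sample({"a": 0.5, LOGLKEY: -2.0, LOGPKEY: 0.0})
        >>> s.parameter_keys
        ['a']
        >>> s.ndim
        1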
        """

        self.sample_dict = sample_dict
        self.keys = list(sample_dict.keys())
        self.parameter_keys = [k for k in self.keys if k not in [LOGPKEY, LOGLKEY]]
        self.ndim = len(self.parameter_keys)

    def __getitem__(self, key):
        return self.sample_dict[key]

    def __setitem__(self, key, value):
        self.sample_dict[key] = value
        if key not in self.keys:
            self.keys = list(self.sample_dict.keys())

    @property
    def list(self):
        return list(self.sample_dict.values())

    def __repr__(self):
        return str(self.sample_dict)

    @property
    def parameter_only_dict(self):
        return {key: self.sample_dict[key] for key in self.parameter_keys}

    @property
    def dict(self):
        return {key: self.sample_dict[key] for key in self.keys}

    def as_dict(self, keys=None):
        sdict = self.dict
        if keys is None:
            return sdict
        else:
            return {key: sdict[key] for key in keys}

    def __eq__(self, other_sample):
        return self.list == other_sample.list

    def copy(self):
        return Sample(self.sample_dict.copy())


def calculate_tau(x, autocorr_c=5):
    import emcee

    if version.parse(emcee.__version__) < version.parse("3"):
        raise SamplerError("bilby-mcmc requires emcee >= 3.0 for autocorr analysis")
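
    # A chain that never moves has an undefined autocorrelation time; treat
    # it as fully correlated so downstream checks see tau = inf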
    if np.all(np.diff(x) == 0):
        return np.inf
    try:
        # Hard code tol=1: we perform this check internally
        tau = emcee.autocorr.integrated_time(x, c=autocorr_c, tol=1)[0]
        if np.isnan(tau):
            tau = np.inf
        return tau
    except emcee.autocorr.AutocorrError:
        return np.inf
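

# Minimal smoke test of calculate_tau (illustrative only, not part of the
# sampler; run via `python -m bilby.bilby_mcmc.chain` so the relative imports
# above resolve). White noise decorrelates immediately, so the estimated tau
# should be close to 1, while a constant series returns inf.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    print(calculate_tau(rng.normal(size=10_000)))  # expect tau ~ 1
    print(calculate_tau(np.ones(100)))  # expect inf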