Coverage for bilby/bilby_mcmc/chain.py: 93%

307 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import numpy as np 

2import pandas as pd 

3from packaging import version 

4 

5from ..core.sampler.base_sampler import SamplerError 

6from ..core.utils import logger 

7from .utils import LOGLKEY, LOGLLATEXKEY, LOGPKEY, LOGPLATEXKEY 

8 

9 

class Chain(object):
    """Growable array-backed store for a single MCMC chain.

    Besides storing the samples, the chain keeps track of the maximum
    log-likelihood seen, the autocorrelation time (ACT, ``tau``) estimates
    and the minimum index from which samples are considered usable.
    """

    def __init__(
        self,
        initial_sample,
        burn_in_nact=1,
        thin_by_nact=1,
        fixed_discard=0,
        autocorr_c=5,
        min_tau=1,
        fixed_tau=None,
        tau_window=None,
        block_length=100000,
    ):
        """Object to store a single mcmc chain

        Parameters
        ----------
        initial_sample: bilby.bilby_mcmc.chain.Sample
            The starting point of the chain
        burn_in_nact, thin_by_nact : int (1, 1)
            The number of autocorrelation times (tau) to discard for burn-in
            and the multiplicative factor to thin by (thin_by_nact < 1). I.e.
            burn_in_nact=10 and thin_by_nact=1 will discard 10*tau samples from
            the start of the chain, then thin the final chain by a factor
            of 1*tau (resulting in independent samples).
        fixed_discard: int (0)
            A fixed minimum number of samples to discard (can be used to
            override the burn_in_nact if it is too small).
        autocorr_c: float (5)
            The step size of the window search used by emcee.autocorr when
            estimating the autocorrelation time.
        min_tau: int (1)
            A minimum value for the autocorrelation time.
        fixed_tau: int (None)
            A fixed value for the autocorrelation (overrides the automated
            autocorrelation time estimation). Used in testing.
        tau_window: int (None)
            Only calculate the autocorrelation time in a trailing window. If
            None (default) this method is not used.
        block_length: int
            The incremental size to extend the array by when it runs out of
            space.
        """
        self.autocorr_c = autocorr_c
        self.min_tau = min_tau
        self.burn_in_nact = burn_in_nact
        self.thin_by_nact = thin_by_nact
        self.block_length = block_length
        self.fixed_discard = int(fixed_discard)
        self.fixed_tau = fixed_tau
        self.tau_window = tau_window

        self.ndim = initial_sample.ndim
        self.current_sample = initial_sample
        self.keys = self.current_sample.keys
        self.parameter_keys = self.current_sample.parameter_keys

        # Initialize chain
        self._chain_array = self._get_zero_chain_array()
        self._chain_array_length = block_length
        # position is the index of the last stored sample; -1 means "empty"
        self.position = -1
        self.max_log_likelihood = -np.inf
        # Maps chain position -> maximum tau calculated at that position
        self.max_tau_dict = {}
        self.converged = False
        self.cached_tau_count = 0
        self._minimum_index_proposal = 0
        self._minimum_index_adapt = 0
        # Cache of (position, minimum_index, method label)
        self._last_minimum_index = (0, 0, "I")
        # Initialise the reporting label so that it can be read before
        # `minimum_index` has computed a value: the cached early return in
        # `minimum_index` does not assign it.
        self.minimum_index_method = self._last_minimum_index[2]
        self.last_full_tau_dict = {key: np.inf for key in self.parameter_keys}

        # Append the initial sample
        self.append(self.current_sample)

    def _get_zero_chain_array(self):
        """Return a zero-filled block of shape (block_length, ndim + 2)."""
        return np.zeros((self.block_length, self.ndim + 2), dtype=np.float64)

    def _extend_chain_array(self):
        """Grow the storage array by one zero-filled block."""
        self._chain_array = np.concatenate(
            (self._chain_array, self._get_zero_chain_array()), axis=0
        )
        self._chain_array_length = len(self._chain_array)

    @property
    def current_sample(self):
        """A copy of the most recently set sample."""
        return self._current_sample.copy()

    @current_sample.setter
    def current_sample(self, current_sample):
        self._current_sample = current_sample

    def append(self, sample):
        """Store ``sample`` at the next position, growing the array if needed.

        Also makes ``sample`` the current sample and updates the running
        maximum log-likelihood.
        """
        self.position += 1

        # Extend the array if needed
        if self.position >= self._chain_array_length:
            self._extend_chain_array()

        # Store the current sample and append to the array
        self.current_sample = sample
        self._chain_array[self.position] = sample.list

        # Update the maximum log_likelihood
        if sample[LOGLKEY] > self.max_log_likelihood:
            self.max_log_likelihood = sample[LOGLKEY]

    def __getitem__(self, index):
        """Return the stored sample at ``index`` as a ``Sample``.

        Negative indices count back from the current position.

        Raises
        ------
        SamplerError
            If, after resolving a negative index, the index does not lie in
            ``[0, position]``.
        """
        if index < 0:
            index = index + self.position + 1

        # A negative index reaching back past the first sample is still
        # negative here. Without the lower bound NumPy would wrap it around
        # to the end of the buffer and silently return the wrong row.
        if 0 <= index <= self.position:
            values = self._chain_array[index]
            return Sample(dict(zip(self.keys, values)))
        else:
            raise SamplerError(f"Requested index {index} out of bounds")

    def __setitem__(self, index, sample):
        """Overwrite the row at ``index`` with the values of ``sample``.

        Negative indices count back from the current position. No bounds
        check is performed here.
        """
        if index < 0:
            index = index + self.position + 1

        self._chain_array[index] = sample.list

    def key_to_idx(self, key):
        """Return the column index of ``key`` in the chain array."""
        return self.keys.index(key)

    def get_1d_array(self, key):
        """Return the stored values of ``key`` up to the current position."""
        return self._chain_array[: 1 + self.position, self.key_to_idx(key)]

    @property
    def _random_idx(self):
        """Draw a random index into the stored part of the chain."""
        from ..core.utils.random import rng

        mindex = self._last_minimum_index[1]
        # Check if mindex exceeds current position by 10 ACT: if so use a random sample
        # otherwise we draw only from the chain past the minimum_index
        if np.isinf(self.tau_last) or self.position - mindex < 10 * self.tau_last:
            mindex = 0
        return rng.integers(mindex, self.position + 1)

    @property
    def random_sample(self):
        """A randomly chosen stored sample (see ``_random_idx``)."""
        return self[self._random_idx]

    @property
    def fixed_discard(self):
        """The fixed minimum number of samples to discard."""
        return self._fixed_discard

    @fixed_discard.setter
    def fixed_discard(self, fixed_discard):
        self._fixed_discard = int(fixed_discard)

    @property
    def minimum_index(self):
        """This calculates a minimum index from which to discard samples

        A number of methods are provided for the calculation. A subset are
        switched off (by `if False` statements) for future development

        The result is the largest of the candidate indices; the label of the
        winning method is stored in ``minimum_index_method`` and the result
        is cached per chain position.
        """
        position = self.position

        # Return cached minimum index
        last_minimum_index = self._last_minimum_index
        if position == last_minimum_index[0]:
            return int(last_minimum_index[1])

        # If fixed discard is not yet reached, just return that
        if position < self.fixed_discard:
            self.minimum_index_method = "FD"
            return self.fixed_discard

        # Initialize list of minimum index methods with the fixed discard (FD)
        minimum_index_list = [self.fixed_discard]
        minimum_index_method_list = ["FD"]

        # Calculate minimum index from tau
        if self.tau_last < np.inf:
            tau = self.tau_last
        elif len(self.max_tau_dict) == 0:
            # Bootstrap calculating tau when minimum index has not yet been calculated
            tau = self._tau_for_full_chain
        else:
            tau = np.inf

        if tau < np.inf:
            minimum_index_list.append(self.burn_in_nact * tau)
            minimum_index_method_list.append(f"{self.burn_in_nact}tau")

        # Calculate points when log-posterior is within z std of the mean
        if True:
            zfactor = 1
            N = 100
            delta_lnP = zfactor * self.ndim / 2
            logl = self.get_1d_array(LOGLKEY)
            log_prior = self.get_1d_array(LOGPKEY)
            log_posterior = logl + log_prior
            max_posterior = np.max(log_posterior)

            # Rolling mean over N points; the first N - 1 entries are
            # undefined for a window of N and are dropped
            ave = pd.Series(log_posterior).rolling(window=N).mean().iloc[N - 1 :]
            delta = max_posterior - ave
            passes = ave[delta < delta_lnP]
            if len(passes) > 0:
                minimum_index_list.append(passes.index[0] + 1)
                minimum_index_method_list.append(f"z{zfactor}")

        # Add last minimum_index_method
        if False:
            minimum_index_list.append(last_minimum_index[1])
            minimum_index_method_list.append(last_minimum_index[2])

        # Minimum index set by proposals
        minimum_index_list.append(self.minimum_index_proposal)
        minimum_index_method_list.append("PR")

        # Minimum index set by temperature adaptation
        minimum_index_list.append(self.minimum_index_adapt)
        minimum_index_method_list.append("AD")

        # Calculate the maximum minimum index and associated method (reporting)
        minimum_index = int(np.max(minimum_index_list))
        minimum_index_method = minimum_index_method_list[np.argmax(minimum_index_list)]

        # Cache the method
        self._last_minimum_index = (position, minimum_index, minimum_index_method)
        self.minimum_index_method = minimum_index_method

        return minimum_index

    @property
    def minimum_index_proposal(self):
        """Minimum index set by proposals; the setter only ever increases it."""
        return self._minimum_index_proposal

    @minimum_index_proposal.setter
    def minimum_index_proposal(self, minimum_index_proposal):
        if minimum_index_proposal > self._minimum_index_proposal:
            self._minimum_index_proposal = minimum_index_proposal

    @property
    def minimum_index_adapt(self):
        """Minimum index set by adaptation; the setter only ever increases it."""
        return self._minimum_index_adapt

    @minimum_index_adapt.setter
    def minimum_index_adapt(self, minimum_index_adapt):
        if minimum_index_adapt > self._minimum_index_adapt:
            self._minimum_index_adapt = minimum_index_adapt

    @property
    def tau(self):
        """The maximum ACT over all parameters"""

        if self.position in self.max_tau_dict:
            # If we have the ACT at the current position, return it
            return self.max_tau_dict[self.position]
        elif (
            self.tau_last < np.inf
            and self.cached_tau_count < 50
            and self.nsamples_last > 50
        ):
            # If we have a recent ACT return it
            self.cached_tau_count += 1
            return self.tau_last
        else:
            # Calculate the ACT
            return self.tau_nocache

    @property
    def tau_nocache(self):
        """Calculate tau forcing a recalculation (no cached tau)"""
        tau = max(self.tau_dict.values())
        self.max_tau_dict[self.position] = tau
        self.cached_tau_count = 0
        return tau

    @property
    def tau_last(self):
        """Return the last-calculated tau if it exists, else inf"""
        if len(self.max_tau_dict) > 0:
            return list(self.max_tau_dict.values())[-1]
        else:
            return np.inf

    @property
    def _tau_for_full_chain(self):
        """The maximum ACT over all parameters, using the chain from index 0"""
        return max(self._tau_dict_for_full_chain.values())

    @property
    def _tau_dict_for_full_chain(self):
        """Dictionary of tau (ACT) per parameter, using the chain from index 0"""
        return self._calculate_tau_dict(minimum_index=0)

    @property
    def tau_dict(self):
        """Calculate a dictionary of tau (ACT) for every parameter"""
        return self._calculate_tau_dict(self.minimum_index)

    def _calculate_tau_dict(self, minimum_index):
        """Calculate a dictionary of tau (ACT) for every parameter

        Parameters
        ----------
        minimum_index: int
            Index from which samples are used for the calculation.

        Returns
        -------
        dict
            Maps each parameter key to its tau; every value is inf when
            fewer than ``2 * autocorr_c`` samples are available.
        """
        logger.debug(f"Calculating tau_dict {self}")

        # If there are too few samples to calculate tau
        if (self.position - minimum_index) < 2 * self.autocorr_c:
            return {key: np.inf for key in self.parameter_keys}

        # Choose minimum index for the ACT calculation: with a tau_window,
        # only the trailing tau_window * last_tau samples are used
        last_tau = self.tau_last
        if self.tau_window is not None and last_tau < np.inf:
            minimum_index_for_act = max(
                minimum_index, int(self.position - self.tau_window * last_tau)
            )
        else:
            minimum_index_for_act = minimum_index

        # Calculate a dictionary of tau's for each parameter
        taus = {}
        for key in self.parameter_keys:
            if self.fixed_tau is None:
                x = self.get_1d_array(key)[minimum_index_for_act:]
                tau = calculate_tau(x, self.autocorr_c)
                taux = round(tau, 1)
            else:
                taux = self.fixed_tau
            taus[key] = max(taux, self.min_tau)

        # Cache the last tau dictionary for future use
        self.last_full_tau_dict = taus

        return taus

    @property
    def thin(self):
        """The thinning stride: ``int(thin_by_nact * tau)``, at least 1."""
        if np.isfinite(self.tau):
            return np.max([1, int(self.thin_by_nact * self.tau)])
        else:
            return 1

    @property
    def nsamples(self):
        """Number of samples after discarding up to minimum_index and thinning."""
        nuseable_steps = self.position - self.minimum_index
        n_independent_samples = nuseable_steps / self.tau
        nsamples = int(n_independent_samples / self.thin_by_nact)
        if nuseable_steps >= nsamples:
            return nsamples
        else:
            return 0

    @property
    def nsamples_last(self):
        """As ``nsamples`` but using the last-calculated tau (no recalculation)."""
        nuseable_steps = self.position - self.minimum_index
        return int(nuseable_steps / (self.thin_by_nact * self.tau_last))

    @property
    def samples(self):
        """DataFrame of the thinned samples between minimum_index and position."""
        samples = self._chain_array[self.minimum_index : self.position : self.thin]
        return pd.DataFrame(samples, columns=self.keys)

    def plot(self, outdir=".", label="label", priors=None, all_samples=None):
        """Save a trace/histogram checkpoint figure for the chain.

        The figure is written to ``{outdir}/{label}_checkpoint_trace.png``.

        Parameters
        ----------
        outdir, label: str
            Used to build the output filename.
        priors: dict-like, optional
            If given, ``priors[key].latex_label`` is used for axis labels.
        all_samples: optional
            If given, ``all_samples[key]`` is histogrammed behind the
            histogram of this chain for every key.
        """
        import matplotlib.pyplot as plt

        fig, axes = plt.subplots(
            nrows=self.ndim + 3, ncols=2, figsize=(8, 9 + 3 * (self.ndim))
        )
        scatter_kwargs = dict(
            lw=0,
            marker="o",
        )
        K = 1000

        nburn = self.minimum_index
        # One entry per layer: discarded samples, all kept samples, thinned
        # kept samples
        plot_setups = zip(
            [0, nburn, nburn],
            [nburn, self.position, self.position],
            [1, 1, self.thin],  # Thin-by factor
            ["tab:red", "tab:grey", "tab:blue"],  # Color
            [0.5, 0.05, 0.5],  # Alpha
            [1, 1, 1],  # Marker size
        )

        position_indexes = np.arange(self.position + 1)

        # Plot the traceplots
        for (start, stop, thin, color, alpha, ms) in plot_setups:
            for ax, key in zip(axes[:, 0], self.keys):
                xx = position_indexes[start:stop:thin] / K
                yy = self.get_1d_array(key)[start:stop:thin]

                # Downsample plots to max_pts: avoid memory issues
                max_pts = 10000
                while len(xx) > max_pts:
                    xx = xx[::2]
                    yy = yy[::2]

                ax.plot(
                    xx,
                    yy,
                    color=color,
                    alpha=alpha,
                    ms=ms,
                    **scatter_kwargs,
                )
                ax.set_ylabel(self._get_plot_label_by_key(key, priors))
                if key not in [LOGLKEY, LOGPKEY]:
                    msg = r"$\tau=$" + f"{self.last_full_tau_dict[key]:0.1f}"
                    ax.set_title(msg)

        # Plot the histograms
        for ax, key in zip(axes[:, 1], self.keys):
            if all_samples is not None:
                yy_all = all_samples[key]
                if np.any(np.isinf(yy_all)):
                    logger.warning(
                        f"Could not plot histogram for parameter {key} due to infinite values"
                    )
                else:
                    ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k")
            yy = self.get_1d_array(key)[nburn : self.position : self.thin]
            if np.any(np.isinf(yy)):
                logger.warning(
                    f"Could not plot histogram for parameter {key} due to infinite values"
                )
            else:
                ax.hist(yy, bins=50, alpha=0.8, density=True)
            ax.set_xlabel(self._get_plot_label_by_key(key, priors))

        # Add x-axes labels to the traceplots
        axes[-1, 0].set_xlabel(r"Iteration $[\times 10^{3}]$")

        # Plot the calculated ACT
        ax = axes[-1, 0]
        tausit = np.array(list(self.max_tau_dict.keys()) + [self.position]) / K
        taus = list(self.max_tau_dict.values()) + [self.tau_last]
        ax.plot(tausit, taus, color="C3")
        ax.set(ylabel=r"Maximum $\tau$")

        axes[-1, 1].set_axis_off()

        filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
        msg = [
            r"Maximum $\tau$" + f"={self.tau:0.1f} ",
            r"$n_{\rm samples}=$" + f"{self.nsamples} ",
        ]
        if self.thin_by_nact != 1:
            msg += [
                r"$n_{\rm samples}^{\rm eff}=$"
                + f"{int(self.nsamples * self.thin_by_nact)} "
            ]
        fig.suptitle(
            "| ".join(msg),
            y=1,
        )
        fig.tight_layout()
        fig.savefig(filename, dpi=200)
        plt.close(fig)

    @staticmethod
    def _get_plot_label_by_key(key, priors=None):
        """Return the axis label for ``key``.

        Uses the prior's ``latex_label`` when ``key`` is in ``priors``, the
        LaTeX constants for the log-likelihood/log-prior keys, and the key
        itself otherwise.
        """
        if priors is not None and key in priors:
            return priors[key].latex_label
        elif key == LOGLKEY:
            return LOGLLATEXKEY
        elif key == LOGPKEY:
            return LOGPLATEXKEY
        else:
            return key

472 

473 

class Sample(object):
    """A single sample: a dictionary of values plus derived key bookkeeping."""

    def __init__(self, sample_dict):
        """A single sample

        Parameters
        ----------
        sample_dict: dict
            A dictionary of the sample
        """

        self.sample_dict = sample_dict
        self._refresh_keys()

    def _refresh_keys(self):
        """Recompute ``keys``, ``parameter_keys`` and ``ndim`` from the dict."""
        self.keys = list(self.sample_dict.keys())
        # Parameter keys are all keys except the log-prior/log-likelihood
        self.parameter_keys = [k for k in self.keys if k not in [LOGPKEY, LOGLKEY]]
        self.ndim = len(self.parameter_keys)

    def __getitem__(self, key):
        return self.sample_dict[key]

    def __setitem__(self, key, value):
        """Set ``key`` to ``value``; a new key refreshes the derived attributes."""
        self.sample_dict[key] = value
        if key not in self.keys:
            # Refresh parameter_keys and ndim together with keys so that a
            # newly added parameter is not left out of parameter_only_dict
            self._refresh_keys()

    @property
    def list(self):
        """The values of the sample, in dictionary order."""
        return list(self.sample_dict.values())

    def __repr__(self):
        return str(self.sample_dict)

    @property
    def parameter_only_dict(self):
        """Dictionary restricted to the parameter keys."""
        return {key: self.sample_dict[key] for key in self.parameter_keys}

    @property
    def dict(self):
        """Dictionary of all keys of the sample."""
        return {key: self.sample_dict[key] for key in self.keys}

    def as_dict(self, keys=None):
        """Return the sample as a dict, optionally restricted to ``keys``."""
        sdict = self.dict
        if keys is None:
            return sdict
        else:
            return {key: sdict[key] for key in keys}

    def __eq__(self, other_sample):
        """Two samples are equal when their value lists are equal.

        Returns ``NotImplemented`` for objects without a ``list`` attribute
        so that e.g. ``sample == None`` evaluates to False instead of raising.
        """
        try:
            other_values = other_sample.list
        except AttributeError:
            return NotImplemented
        return self.list == other_values

    # Samples are mutable and define __eq__, so they are unhashable. Python
    # already implies this when __eq__ is defined; it is stated explicitly.
    __hash__ = None

    def copy(self):
        """Return a new Sample built from a shallow copy of the dictionary."""
        return Sample(self.sample_dict.copy())

524 

525 

def calculate_tau(x, autocorr_c=5):
    """Estimate the integrated autocorrelation time of a 1D series.

    Parameters
    ----------
    x: array_like
        The series of values to analyse.
    autocorr_c: float (5)
        Passed to ``emcee.autocorr.integrated_time`` as ``c``.

    Returns
    -------
    float
        The estimated autocorrelation time. ``np.inf`` is returned when all
        consecutive differences of ``x`` are zero, when the estimate is NaN,
        or when emcee raises ``AutocorrError``.

    Raises
    ------
    SamplerError
        If the installed emcee version is older than 3.
    """
    import emcee

    # Major version 3 itself passes this check, so the requirement is ">= 3.0"
    if version.parse(emcee.__version__) < version.parse("3"):
        raise SamplerError("bilby-mcmc requires emcee >= 3.0 for autocorr analysis")

    # A series that never changes has no usable autocorrelation estimate
    if np.all(np.diff(x) == 0):
        return np.inf
    try:
        # Hard code tol=1: we perform this check internally
        tau = emcee.autocorr.integrated_time(x, c=autocorr_c, tol=1)[0]
        if np.isnan(tau):
            tau = np.inf
        return tau
    except emcee.autocorr.AutocorrError:
        return np.inf