Coverage for bilby/core/sampler/emcee.py: 35%

220 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import os 

2import shutil 

3from collections import namedtuple 

4 

5import numpy as np 

6from packaging import version 

7from pandas import DataFrame 

8 

9from ..utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump 

10from .base_sampler import MCMCSampler, SamplerError, signal_wrapper 

11from .ptemcee import LikePriorEvaluator 

12 

# Module-level likelihood/prior evaluator; `_evaluator.call_emcee` is handed to
# emcee as the log-posterior function in `Emcee.sampler_init_kwargs` below.
_evaluator = LikePriorEvaluator()

14 

15 

class Emcee(MCMCSampler):
    """bilby wrapper emcee (https://github.com/dfm/emcee)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `emcee.EnsembleSampler`, see
    documentation for that class for further help. Under Other Parameters, we
    list commonly used kwargs and the bilby defaults.

    Parameters
    ==========
    nwalkers: int, (500)
        The number of walkers
    nsteps: int, (100)
        The number of steps
    nburn: int (None)
        If given, the fixed number of steps to discard as burn-in. These will
        be discarded from the total number of steps set by `nsteps` and
        therefore the value must be less than `nsteps`. Else, nburn is
        estimated from the autocorrelation time
    burn_in_fraction: float, (0.25)
        The fraction of steps to discard as burn-in in the event that the
        autocorrelation time cannot be calculated
    burn_in_act: float
        The number of autocorrelation times to discard as burn-in
    a: float (2)
        The proposal scale factor
    verbose: bool
        Whether to print diagnostic information during the analysis

    """

    sampler_name = "emcee"
    # Defaults use the emcee <= 2.2.1 keyword names; they are translated to
    # the emcee >= 3 equivalents in the sampler_*_kwargs properties below.
    default_kwargs = dict(
        nwalkers=500,
        a=2,
        args=[],
        kwargs={},
        postargs=None,
        pool=None,
        live_dangerously=False,
        runtime_sortingfn=None,
        lnprob0=None,
        rstate0=None,
        blobs0=None,
        iterations=100,
        thin=1,
        storechain=True,
        mh_proposal=None,
    )

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        skip_import_verification=False,
        pos0=None,
        nburn=None,
        burn_in_fraction=0.25,
        resume=True,
        burn_in_act=3,
        **kwargs,
    ):
        # Determine the installed emcee API generation before the base class
        # processes the keyword arguments.
        self._check_version()
        super(Emcee, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            **kwargs,
        )
        self.resume = resume
        self.pos0 = pos0
        self.nburn = nburn
        self.burn_in_fraction = burn_in_fraction
        self.burn_in_act = burn_in_act
        self.verbose = kwargs.get("verbose", True)

    def _check_version(self):
        """Set ``self.prerelease`` to True for emcee >= 3, else False.

        emcee renamed several keyword arguments at version 3; the
        ``sampler_init_kwargs`` and ``sampler_function_kwargs`` properties use
        this flag to pick the correct names.
        """
        import emcee

        self.prerelease = version.parse(emcee.__version__) >= version.parse("3")

    def _translate_kwargs(self, kwargs):
        """Map user-facing aliases onto the emcee keyword names (in place)."""
        kwargs = super()._translate_kwargs(kwargs)
        if "nwalkers" not in kwargs:
            for equiv in self.nwalkers_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["nwalkers"] = kwargs.pop(equiv)
        if "iterations" not in kwargs and "nsteps" in kwargs:
            kwargs["iterations"] = kwargs.pop("nsteps")

    @property
    def sampler_function_kwargs(self):
        """dict: Keyword arguments passed to ``EnsembleSampler.sample``.

        The subset of ``self.kwargs`` consumed by the ``sample`` call, with
        keys renamed as required for emcee >= 3.
        """
        keys = [
            "lnprob0",
            "rstate0",
            "blobs0",
            "iterations",
            "thin",
            "storechain",
            "mh_proposal",
        ]

        # updated function keywords for emcee > v2.2.1
        updatekeys = {
            "p0": "initial_state",
            "lnprob0": "log_prob0",
            "storechain": "store",
        }

        function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
        function_kwargs["p0"] = self.pos0

        if self.prerelease:
            if function_kwargs["mh_proposal"] is not None:
                logger.warning(
                    "The 'mh_proposal' option is no longer used "
                    "in emcee > 2.2.1, and will be ignored."
                )
            del function_kwargs["mh_proposal"]

            for key in updatekeys:
                if updatekeys[key] not in function_kwargs:
                    function_kwargs[updatekeys[key]] = function_kwargs.pop(key)
                else:
                    del function_kwargs[key]

        return function_kwargs

    @property
    def sampler_init_kwargs(self):
        """dict: Keyword arguments passed to the ``EnsembleSampler`` constructor."""
        init_kwargs = {
            key: value
            for key, value in self.kwargs.items()
            if key not in self.sampler_function_kwargs
        }

        init_kwargs["lnpostfn"] = _evaluator.call_emcee
        init_kwargs["dim"] = self.ndim

        # updated init keywords for emcee > v2.2.1
        updatekeys = {"dim": "ndim", "lnpostfn": "log_prob_fn"}

        if self.prerelease:
            for key in updatekeys:
                if key in init_kwargs:
                    init_kwargs[updatekeys[key]] = init_kwargs.pop(key)

            oldfunckeys = ["p0", "lnprob0", "storechain", "mh_proposal"]
            for key in oldfunckeys:
                if key in init_kwargs:
                    del init_kwargs[key]

        return init_kwargs

    @property
    def nburn(self):
        """int: The number of steps discarded as burn-in.

        Uses the explicitly-set value if given; otherwise derived from the
        maximum autocorrelation time, falling back to
        ``burn_in_fraction * nsteps`` when no autocorrelation time is
        available.
        """
        # isinstance (rather than a type() comparison) matches the acceptance
        # rule applied by the setter below
        if isinstance(self.__nburn, (float, int)):
            return int(self.__nburn)
        elif self.result.max_autocorrelation_time is None:
            return int(self.burn_in_fraction * self.nsteps)
        else:
            return int(self.burn_in_act * self.result.max_autocorrelation_time)

    @nburn.setter
    def nburn(self, nburn):
        # A numeric nburn must leave at least one post-burn-in sample
        if isinstance(nburn, (float, int)):
            if nburn > self.kwargs["iterations"] - 1:
                raise ValueError(
                    "Number of burn-in samples must be smaller "
                    "than the total number of iterations"
                )

        self.__nburn = nburn

    @property
    def nwalkers(self):
        """int: The number of ensemble walkers."""
        return self.kwargs["nwalkers"]

    @property
    def nsteps(self):
        """int: The number of iterations the sampler runs for."""
        return self.kwargs["iterations"]

    @nsteps.setter
    def nsteps(self, nsteps):
        self.kwargs["iterations"] = nsteps

    @property
    def stored_chain(self):
        """Read the stored zero-temperature chain data in from disk"""
        return np.genfromtxt(self.checkpoint_info.chain_file, names=True)

    @property
    def stored_samples(self):
        """Returns the samples stored on disk"""
        return self.stored_chain[self.search_parameter_keys]

    @property
    def stored_loglike(self):
        """Returns the log-likelihood stored on disk"""
        return self.stored_chain["log_l"]

    @property
    def stored_logprior(self):
        """Returns the log-prior stored on disk"""
        return self.stored_chain["log_p"]

    def _init_chain_file(self):
        """Write the tab-separated header line of the on-disk chain file."""
        with open(self.checkpoint_info.chain_file, "w+") as ff:
            search_keys_str = "\t".join(self.search_parameter_keys)
            ff.write(f"walker\t{search_keys_str}\tlog_l\tlog_p\n")

    @property
    def checkpoint_info(self):
        """Defines various things related to checkpointing and storing data

        Returns
        =======
        checkpoint_info: named_tuple
            An object with attributes `sampler_file`, `chain_file`, and
            `chain_template`. The first two give paths to where the sampler and
            chain data is stored, the last a formatted-str-template with which
            to write the chain data to disk

        """
        out_dir = os.path.join(
            self.outdir, f"{self.__class__.__name__.lower()}_{self.label}"
        )
        check_directory_exists_and_if_not_mkdir(out_dir)

        chain_file = os.path.join(out_dir, "chain.dat")
        sampler_file = os.path.join(out_dir, "sampler.pickle")
        # One integer walker index, then one column per parameter plus the
        # log-likelihood and log-prior columns
        chain_template = (
            "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n"
        )

        CheckpointInfo = namedtuple(
            "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]
        )

        checkpoint_info = CheckpointInfo(
            sampler_file=sampler_file,
            chain_file=chain_file,
            chain_template=chain_template,
        )

        return checkpoint_info

    @property
    def sampler_chain(self):
        """The sampler chain truncated to the completed iterations."""
        nsteps = self._previous_iterations
        return self.sampler.chain[:, :nsteps, :]

    def write_current_state(self):
        """
        Writes a pickle file of the sampler to disk using dill

        Overwrites the stored sampler chain with one that is truncated
        to only the completed steps
        """
        logger.info(
            f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}"
        )
        self.sampler._chain = self.sampler_chain
        # The multiprocessing pool cannot be pickled; detach it while dumping
        # and reattach afterwards
        _pool = self.sampler.pool
        self.sampler.pool = None
        safe_file_dump(self._sampler, self.checkpoint_info.sampler_file, "dill")
        self.sampler.pool = _pool

    def _initialise_sampler(self):
        """Create a fresh ``EnsembleSampler`` and the on-disk chain file."""
        from emcee import EnsembleSampler

        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
        self._init_chain_file()

    @property
    def sampler(self):
        """Returns the emcee sampler object

        If, already initialized, returns the stored _sampler value. Otherwise,
        first checks if there is a pickle file from which to load. If there is
        not, then initialize the sampler and set the initial random draw

        """
        if hasattr(self, "_sampler"):
            pass
        elif (
            self.resume
            and os.path.isfile(self.checkpoint_info.sampler_file)
            and os.path.getsize(self.checkpoint_info.sampler_file)
        ):
            import dill

            logger.info(
                f"Resuming run from checkpoint file {self.checkpoint_info.sampler_file}"
            )
            with open(self.checkpoint_info.sampler_file, "rb") as f:
                self._sampler = dill.load(f)
            self._sampler.pool = self.pool
            self._set_pos0_for_resume()
        else:
            self._initialise_sampler()
            self._set_pos0()
        return self._sampler

    def write_chains_to_file(self, sample):
        """Append the current ensemble state to the on-disk chain file.

        Parameters
        ==========
        sample:
            The object yielded by ``EnsembleSampler.sample``: a ``State`` with
            ``coords``/``blobs`` attributes for emcee >= 3, else a tuple whose
            elements 0 and 3 are the positions and blobs.
        """
        chain_file = self.checkpoint_info.chain_file
        temp_chain_file = chain_file + ".temp"
        if self.prerelease:
            points = np.hstack([sample.coords, sample.blobs])
        else:
            points = np.hstack([sample[0], np.array(sample[3])])
        data_to_write = "\n".join(
            self.checkpoint_info.chain_template.format(ii, *point)
            for ii, point in enumerate(points)
        )
        # Write to a temporary file first so an interruption cannot leave a
        # partially-written block in the main chain file
        with open(temp_chain_file, "w") as ff:
            ff.write(data_to_write)
        with open(temp_chain_file, "rb") as ftemp, open(chain_file, "ab") as fchain:
            shutil.copyfileobj(ftemp, fchain)
        os.remove(temp_chain_file)

    @property
    def _previous_iterations(self):
        """Returns the number of iterations that the sampler has saved

        This is used when loading in a sampler from a pickle file to figure out
        how much of the run has already been completed
        """
        try:
            return len(self.sampler.blobs)
        except AttributeError:
            return 0

    def _draw_pos0_from_prior(self):
        """Draw an initial position for every walker from the prior."""
        return np.array(
            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
        )

    @property
    def _pos0_shape(self):
        # Expected shape of the initial-position array
        return (self.nwalkers, self.ndim)

    def _set_pos0(self):
        """Validate user-provided initial positions, or draw from the prior."""
        if self.pos0 is not None:
            logger.debug("Using given initial positions for walkers")
            if isinstance(self.pos0, DataFrame):
                self.pos0 = self.pos0[self.search_parameter_keys].values
            elif type(self.pos0) in (list, np.ndarray):
                self.pos0 = np.squeeze(self.pos0)

            if self.pos0.shape != self._pos0_shape:
                raise ValueError("Input pos0 should be of shape ndim, nwalkers")
            logger.debug("Checking input pos0")
            for draw in self.pos0:
                self.check_draw(draw)
        else:
            logger.debug("Generating initial walker positions from prior")
            self.pos0 = self._draw_pos0_from_prior()

    def _set_pos0_for_resume(self):
        # Restart from the final position of the checkpointed chain
        self.pos0 = self.sampler.chain[:, -1, :]

    @signal_wrapper
    def run_sampler(self):
        """Run the sampler, checkpointing every iteration, and return the result."""
        self._setup_pool()
        from tqdm.auto import tqdm

        sampler_function_kwargs = self.sampler_function_kwargs
        iterations = sampler_function_kwargs.pop("iterations")
        # Only run the iterations remaining after a checkpoint resume
        iterations -= self._previous_iterations

        if self.prerelease:
            sampler_function_kwargs["initial_state"] = self.pos0
        else:
            sampler_function_kwargs["p0"] = self.pos0

        # main iteration loop
        iterator = self.sampler.sample(iterations=iterations, **sampler_function_kwargs)
        if self.verbose:
            iterator = tqdm(iterator, total=iterations)
        for sample in iterator:
            self.write_chains_to_file(sample)
        if self.verbose:
            iterator.close()
        self.write_current_state()
        self._close_pool()

        self.result.sampler_output = np.nan
        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
        self.print_nburn_logging_info()

        self._generate_result()

        # Discard burn-in and flatten (walkers, steps, ndim) -> (samples, ndim)
        self.result.samples = self.sampler.chain[:, self.nburn :, :].reshape(
            (-1, self.ndim)
        )
        self.result.walkers = self.sampler.chain
        return self.result

    def _generate_result(self):
        """Populate ``self.result`` with burn-in and log-likelihood/prior values.

        Raises
        ======
        SamplerError: If the estimated burn-in exceeds the chain length.
        """
        self.result.nburn = self.nburn
        self.calc_likelihood_count()
        if self.result.nburn > self.nsteps:
            # NOTE: the inequality in the message matches the failing
            # condition (previously it was printed inverted)
            raise SamplerError(
                "The run has finished, but the chain is not burned in: "
                f"`nburn > nsteps` ({self.result.nburn} > {self.nsteps})."
                " Try increasing the number of steps."
            )
        # blobs hold (log_likelihood, log_prior) pairs per step per walker
        blobs = np.array(self.sampler.blobs)
        blobs_trimmed = blobs[self.nburn :, :, :].reshape((-1, 2))
        log_likelihoods, log_priors = blobs_trimmed.T
        self.result.log_likelihood_evaluations = log_likelihoods
        self.result.log_prior_evaluations = log_priors
        # An MCMC run produces no evidence estimate
        self.result.log_evidence = np.nan
        self.result.log_evidence_err = np.nan