Coverage for bilby/core/sampler/__init__.py: 74%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import datetime 

2import inspect 

3import sys 

4 

5from ..prior import DeltaFunction, PriorDict 

6from ..utils import ( 

7 command_line_args, 

8 env_package_list, 

9 get_entry_points, 

10 loaded_modules_dict, 

11 logger, 

12) 

13from . import proposal 

14from .base_sampler import Sampler, SamplingMarginalisedParameterError 

15 

16 

class ImplementedSamplers:
    """Dictionary-like object that contains implemented samplers.

    Samplers are discovered through the ``"bilby.samplers"`` entry-point
    group. Entries whose name starts with ``"bilby."`` can also be addressed
    by their short name (without the prefix).

    This class is a singleton: only one instance can exist.
    """

    _instance = None

    # Mapping of entry-point name -> entry point, built once at class creation.
    _samplers = get_entry_points("bilby.samplers")

    # Prefix that can be omitted when referring to a sampler by name.
    _PREFIX = "bilby."

    @classmethod
    def _short_name(cls, key):
        """Return ``key`` without its leading ``"bilby."`` prefix, if any.

        Only a *leading* prefix is removed so that the short name always maps
        back onto ``f"bilby.{short_name}"``, which is what ``__getitem__``
        looks up. Names that merely contain ``"bilby."`` somewhere else are
        returned unchanged.
        """
        if key.startswith(cls._PREFIX):
            return key[len(cls._PREFIX):]
        return key

    def keys(self):
        """Iterator of available samplers by name.

        Reduces the list to its simplest. This includes removing the 'bilby.'
        prefix from native samplers if a corresponding plugin is not available.
        """
        names = []
        for key in self._samplers:
            short = self._short_name(key)
            # Keep the full key when the short name is itself registered
            # (either it is the same key or it would be ambiguous).
            names.append(key if short in self._samplers else short)
        return iter(names)

    def values(self):
        """Iterator of sampler classes.

        Note: the classes need to be loaded using :code:`.load()` before being
        called.
        """
        return iter(self._samplers.values())

    def items(self):
        """Iterator of tuples containing keys (sampler names) and classes.

        Note: the classes need to be loaded using :code:`.load()` before being
        called.
        """
        return zip(self.keys(), self.values())

    def valid_keys(self):
        """All valid keys including bilby.<sampler name>."""
        keys = set(self._samplers)
        return iter(keys | {self._short_name(k) for k in keys})

    def __getitem__(self, key):
        """Return the entry registered as ``key`` or as ``bilby.<key>``.

        Raises
        ------
        ValueError
            If neither ``key`` nor ``bilby.<key>`` is registered.
        """
        if key in self._samplers:
            return self._samplers[key]
        prefixed = f"{self._PREFIX}{key}"
        if prefixed in self._samplers:
            return self._samplers[prefixed]
        raise ValueError(
            f"Sampler {key} is not implemented! "
            f"Available samplers are: {list(self.keys())}"
        )

    def __contains__(self, value):
        return value in self.valid_keys()

    def __new__(cls):
        # Create the single shared instance lazily and reuse it afterwards.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance


IMPLEMENTED_SAMPLERS = ImplementedSamplers()

84 

85 

def get_implemented_samplers():
    """Return the names of every implemented sampler.

    Both natively supported samplers (e.g. dynesty) and samplers provided
    through sampler plugins are included.

    Returns
    -------
    list
        The names yielded by ``IMPLEMENTED_SAMPLERS.keys()``.
    """
    return [*IMPLEMENTED_SAMPLERS.keys()]

98 

99 

def get_sampler_class(sampler):
    """Look up a sampler by name and load its class.

    Both natively supported samplers (e.g. dynesty) and samplers provided
    through sampler plugins can be requested. The lookup uses the lower-cased
    name.

    Parameters
    ----------
    sampler : str
        The name of the sampler.

    Returns
    -------
    Sampler
        The sampler class.

    Raises
    ------
    ValueError
        Raised if the sampler is not implemented.
    """
    entry = IMPLEMENTED_SAMPLERS[sampler.lower()]
    return entry.load()

122 

123 

# Import-time handling of the ``sampler_help`` command line option: print the
# requested sampler's docstring (or a hint listing the available samplers)
# and then exit the interpreter.
if command_line_args.sampler_help:
    sampler = command_line_args.sampler_help
    if sampler not in IMPLEMENTED_SAMPLERS:
        print(
            "For help with a specific sampler, call sampler-help with "
            "the name of the sampler"
            if sampler == "None"
            else f"Requested sampler {sampler} not implemented"
        )
        print(f"Available samplers = {get_implemented_samplers()}")
    else:
        sampler_class = IMPLEMENTED_SAMPLERS[sampler].load()
        print(f'Help for sampler "{sampler}":')
        print(sampler_class.__doc__)

    sys.exit()

141 

142 

def run_sampler(
    likelihood,
    priors=None,
    label="label",
    outdir="outdir",
    sampler="dynesty",
    use_ratio=None,
    injection_parameters=None,
    conversion_function=None,
    plot=False,
    default_priors_file=None,
    clean=None,
    meta_data=None,
    save=True,
    gzip=False,
    result_class=None,
    npool=1,
    **kwargs,
):
    """
    The primary interface to easy parameter estimation

    Parameters
    ==========
    likelihood: `bilby.Likelihood`
        A `Likelihood` instance
    priors: `bilby.PriorDict`
        A PriorDict/dictionary of the priors for each parameter - missing
        parameters will use default priors, if None, all priors will be default
    label: str
        Name for the run, used in output files
    outdir: str
        A string used in defining output files
    sampler: str, Sampler
        The name of the sampler to use - see
        `bilby.sampler.get_implemented_samplers()` for a list of available
        samplers.
        Alternatively a Sampler object, or a sampler class to be
        instantiated, can be passed
    use_ratio: bool (False)
        If True, use the likelihood's log_likelihood_ratio, rather than just
        the log_likelihood.
    injection_parameters: dict
        A dictionary of injection parameters used in creating the data (if
        using simulated data). Appended to the result object and saved.
    plot: bool
        If true, generate a corner plot and, if applicable diagnostic plots
    conversion_function: function, optional
        Function to apply to posterior to generate additional parameters.
    default_priors_file: str
        If given, a file containing the default priors; otherwise defaults to
        the bilby defaults for a binary black hole.
    clean: bool
        If given, override the command line interface `clean` option.
    meta_data: dict
        If given, adds the key-value pairs to the 'results' object before
        saving. For example, if `meta_data={dtype: 'signal'}`. Warning: in case
        of conflict with keys saved by bilby, the meta_data keys will be
        overwritten.
    save: bool, str
        If true, save the priors and results to disk.
        If hdf5, save as an hdf5 file instead of json.
        If pickle or pkl, save as a pickle file instead of json.
    gzip: bool
        If true, and save is true, gzip the saved results file.
    result_class: bilby.core.result.Result, or child of
        The result class to use. By default, `bilby.core.result.Result` is used,
        but objects which inherit from this class can be given providing
        additional methods.
    npool: int
        An integer specifying the available CPUs to create pool objects for
        parallelization.
    **kwargs:
        All kwargs are passed directly to the samplers `run` function

    Returns
    =======
    result: bilby.core.result.Result
        An object containing the results

    Raises
    ======
    ValueError
        If `priors` is neither a dict nor a PriorDict, or if `sampler` is not
        a Sampler object, a sampler class or a string.
    """

    logger.info(f"Running for label '{label}', output will be saved to '{outdir}'")

    # NOTE(review): only a truthy `clean` overrides the command line option;
    # `clean=False` leaves it untouched - confirm this is intended.
    if clean:
        command_line_args.clean = clean
    if command_line_args.clean:
        kwargs["resume"] = False

    if priors is None:
        priors = dict()

    _check_marginalized_parameters_not_sampled(likelihood, priors)

    # Exact type check on purpose: only plain dicts are wrapped, PriorDict
    # instances are used as they are.
    if type(priors) is dict:
        priors = PriorDict(priors)
    elif not isinstance(priors, PriorDict):
        raise ValueError("Input priors not understood should be dict or PriorDict")

    priors.fill_priors(likelihood, default_priors_file=default_priors_file)

    # Generate the meta-data if not given and append the likelihood meta_data
    if meta_data is None:
        meta_data = dict()
    likelihood.label = label
    likelihood.outdir = outdir
    meta_data["likelihood"] = likelihood.meta_data
    meta_data["loaded_modules"] = loaded_modules_dict()
    meta_data["environment_packages"] = env_package_list(as_dataframe=True)

    if command_line_args.bilby_zero_likelihood_mode:
        from bilby.core.likelihood import ZeroLikelihood

        likelihood = ZeroLikelihood(likelihood)

    if not isinstance(sampler, Sampler):
        if isinstance(sampler, str):
            sampler_class = get_sampler_class(sampler)
        elif inspect.isclass(sampler):
            sampler_class = sampler
        else:
            raise ValueError(
                "Provided sampler should be a Sampler object or name of a known "
                f"sampler: {get_implemented_samplers()}."
            )
        # Instantiate by calling the class. Calling ``__init__`` directly on
        # the class would bind ``likelihood`` as ``self`` and return None.
        sampler = sampler_class(
            likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            injection_parameters=injection_parameters,
            meta_data=meta_data,
            use_ratio=use_ratio,
            plot=plot,
            result_class=result_class,
            npool=npool,
            **kwargs,
        )

    if sampler.cached_result:
        logger.warning("Using cached result")
        result = sampler.cached_result

        if None not in [result.injection_parameters, conversion_function]:
            result.injection_parameters = conversion_function(
                result.injection_parameters
            )
    else:
        # Run the sampler
        start_time = datetime.datetime.now()
        if command_line_args.bilby_test_mode:
            result = sampler._run_test()
        else:
            result = sampler.run_sampler()
        end_time = datetime.datetime.now()

        # Some samplers calculate the sampling time internally
        if result.sampling_time is None:
            result.sampling_time = end_time - start_time
        elif isinstance(result.sampling_time, (float, int)):
            # NOTE(review): the first positional argument of timedelta is
            # *days*; if samplers report seconds this should presumably be
            # ``timedelta(seconds=...)`` - confirm against the samplers.
            result.sampling_time = datetime.timedelta(result.sampling_time)

        logger.info(f"Sampling time: {result.sampling_time}")
        # Convert sampling time into seconds
        result.sampling_time = result.sampling_time.total_seconds()

        result.log_noise_evidence = likelihood.noise_log_likelihood()
        if sampler.use_ratio:
            # The sampler worked with the likelihood ratio, so what it reports
            # as evidence is the Bayes factor.
            result.log_bayes_factor = result.log_evidence
            result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
        else:
            result.log_bayes_factor = result.log_evidence - result.log_noise_evidence

        # Convert the injection parameters exactly once, before the initial
        # save, so a non-idempotent conversion function is not applied twice.
        if None not in [result.injection_parameters, conversion_function]:
            result.injection_parameters = conversion_function(
                result.injection_parameters
            )

        # Initial save of the sampler in case of failure in samples_to_posterior
        if save:
            result.save_to_file(extension=save, gzip=gzip, outdir=outdir)

    # Check if the posterior has already been created
    if getattr(result, "_posterior", None) is None:
        result.samples_to_posterior(
            likelihood=likelihood,
            priors=result.priors,
            conversion_function=conversion_function,
            npool=npool,
        )

    if save:
        # The overwrite here ensures we overwrite the initially stored data
        result.save_to_file(overwrite=True, extension=save, gzip=gzip, outdir=outdir)

    if plot:
        result.plot_corner()
    logger.info(f"Summary of results:\n{result}")
    return result

355 

356 

357def _check_marginalized_parameters_not_sampled(likelihood, priors): 

358 for key in likelihood.marginalized_parameters: 

359 if key in priors: 

360 if not isinstance(priors[key], (float, DeltaFunction)): 

361 raise SamplingMarginalisedParameterError( 

362 f"Likelihood is {key} marginalized but you are trying to sample in {key}. " 

363 )