Coverage for bilby/core/sampler/__init__.py: 74%
126 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import datetime
2import inspect
3import sys
5from ..prior import DeltaFunction, PriorDict
6from ..utils import (
7 command_line_args,
8 env_package_list,
9 get_entry_points,
10 loaded_modules_dict,
11 logger,
12)
13from . import proposal
14from .base_sampler import Sampler, SamplingMarginalisedParameterError
class ImplementedSamplers:
    """Singleton, dictionary-like registry of the implemented samplers.

    The registry is populated once, at class-definition time, from the
    ``bilby.samplers`` entry-point group. Every call to the constructor
    returns the same shared instance.
    """

    # The one shared instance; created lazily by ``__new__``.
    _instance = None

    # Registered samplers, keyed by entry-point name.
    _samplers = get_entry_points("bilby.samplers")

    def __new__(cls):
        # Create the shared instance on first use and hand it back afterwards.
        existing = cls._instance
        if existing is None:
            existing = cls._instance = super().__new__(cls)
        return existing

    def keys(self):
        """Iterator of available samplers by name.

        Names are reduced to their simplest form: the 'bilby.' prefix of a
        native sampler is removed unless the shortened name is itself a
        registered sampler, in which case the full name is kept.
        """
        registered = self._samplers.keys()
        simplified = []
        for full_name in registered:
            short_name = full_name.replace("bilby.", "")
            simplified.append(full_name if short_name in registered else short_name)
        return iter(simplified)

    def values(self):
        """Iterator of sampler classes.

        Note: the classes need to be loaded using :code:`.load()` before
        being called.
        """
        registered = self._samplers
        return iter(registered.values())

    def items(self):
        """Iterator of (sampler name, sampler class) tuples.

        Note: the classes need to be loaded using :code:`.load()` before
        being called.
        """
        return iter(zip(self.keys(), self.values()))

    def valid_keys(self):
        """All valid keys including bilby.<sampler name>."""
        full_names = set(self._samplers.keys())
        short_names = {name.replace("bilby.", "") for name in full_names}
        return iter(full_names.union(short_names))

    def __getitem__(self, key):
        # Exact match first, then fall back to the 'bilby.'-prefixed name.
        if key in self._samplers:
            return self._samplers[key]
        prefixed = f"bilby.{key}"
        if prefixed in self._samplers:
            return self._samplers[prefixed]
        raise ValueError(
            f"Sampler {key} is not implemented! "
            f"Available samplers are: {list(self.keys())}"
        )

    def __contains__(self, value):
        accepted = self.valid_keys()
        return value in accepted
# Module-level registry of implemented samplers; ImplementedSamplers is a
# singleton, so this is the one shared instance used by the functions below.
IMPLEMENTED_SAMPLERS = ImplementedSamplers()
def get_implemented_samplers():
    """Return the names of all implemented samplers.

    Covers both the natively supported samplers (e.g. dynesty) and any
    samplers made available through the sampler plugins.

    Returns
    -------
    list
        The list of implemented samplers.
    """
    return [*IMPLEMENTED_SAMPLERS.keys()]
def get_sampler_class(sampler):
    """Look up and load the class of a sampler given its name.

    Covers both the natively supported samplers (e.g. dynesty) and any
    samplers made available through the sampler plugins. The lookup is
    case-insensitive: the name is lower-cased before it is resolved.

    Parameters
    ----------
    sampler : str
        The name of the sampler.

    Returns
    -------
    Sampler
        The sampler class.

    Raises
    ------
    ValueError
        Raised if the sampler is not implemented.
    """
    entry_point = IMPLEMENTED_SAMPLERS[sampler.lower()]
    return entry_point.load()
# Executed at import time: when the sampler-help command line option is set,
# print the requested help text and terminate the interpreter.
if command_line_args.sampler_help:
    sampler = command_line_args.sampler_help
    if sampler not in IMPLEMENTED_SAMPLERS:
        # Unknown (or no) sampler requested: explain, then list the options.
        if sampler == "None":
            message = (
                "For help with a specific sampler, call sampler-help with "
                "the name of the sampler"
            )
        else:
            message = f"Requested sampler {sampler} not implemented"
        print(message)
        print(f"Available samplers = {get_implemented_samplers()}")
    else:
        # Known sampler: show the docstring of its class.
        sampler_class = IMPLEMENTED_SAMPLERS[sampler].load()
        print(f'Help for sampler "{sampler}":')
        print(sampler_class.__doc__)

    sys.exit()
def run_sampler(
    likelihood,
    priors=None,
    label="label",
    outdir="outdir",
    sampler="dynesty",
    use_ratio=None,
    injection_parameters=None,
    conversion_function=None,
    plot=False,
    default_priors_file=None,
    clean=None,
    meta_data=None,
    save=True,
    gzip=False,
    result_class=None,
    npool=1,
    **kwargs,
):
    """
    The primary interface to easy parameter estimation

    Parameters
    ==========
    likelihood: `bilby.Likelihood`
        A `Likelihood` instance
    priors: `bilby.PriorDict`
        A PriorDict/dictionary of the priors for each parameter - missing
        parameters will use default priors, if None, all priors will be default
    label: str
        Name for the run, used in output files
    outdir: str
        A string used in defining output files
    sampler: str, Sampler
        The name of the sampler to use - see
        `bilby.sampler.get_implemented_samplers()` for a list of available
        samplers.
        Alternatively a Sampler object can be passed, or a sampler class,
        which is then instantiated with the same arguments as a named
        sampler.
    use_ratio: bool (False)
        If True, use the likelihood's log_likelihood_ratio, rather than just
        the log_likelihood.
    injection_parameters: dict
        A dictionary of injection parameters used in creating the data (if
        using simulated data). Appended to the result object and saved.
    plot: bool
        If true, generate a corner plot and, if applicable diagnostic plots
    conversion_function: function, optional
        Function to apply to posterior to generate additional parameters.
        It is also applied (once) to the result's injection parameters.
    default_priors_file: str
        If given, a file containing the default priors; otherwise defaults to
        the bilby defaults for a binary black hole.
    clean: bool
        If given, override the command line interface `clean` option.
    meta_data: dict
        If given, adds the key-value pairs to the 'results' object before
        saving. For example, if `meta_data={dtype: 'signal'}`. Warning: in case
        of conflict with keys saved by bilby, the meta_data keys will be
        overwritten.
    save: bool, str
        If true, save the priors and results to disk.
        If hdf5, save as an hdf5 file instead of json.
        If pickle or pkl, save as a pickle file instead of json.
    gzip: bool
        If true, and save is true, gzip the saved results file.
    result_class: bilby.core.result.Result, or child of
        The result class to use. By default, `bilby.core.result.Result` is used,
        but objects which inherit from this class can be given providing
        additional methods.
    npool: int
        An integer specifying the available CPUs to create pool objects for
        parallelization.
    **kwargs:
        All kwargs are passed directly to the sampler when it is constructed
        from a name or a class. If the `clean` option is active, `resume` is
        forced to False.

    Returns
    =======
    result: bilby.core.result.Result
        An object containing the results

    Raises
    ======
    ValueError
        If `priors` is neither a dict nor a PriorDict, or if `sampler` is not
        a Sampler instance, a sampler class or the name of a known sampler.
    """

    logger.info(f"Running for label '{label}', output will be saved to '{outdir}'")

    if clean:
        command_line_args.clean = clean
    if command_line_args.clean:
        kwargs["resume"] = False

    if priors is None:
        priors = dict()

    _check_marginalized_parameters_not_sampled(likelihood, priors)

    # Check PriorDict first so an existing PriorDict is never re-wrapped;
    # any other dict (including subclasses) is promoted to a PriorDict.
    if isinstance(priors, PriorDict):
        pass
    elif isinstance(priors, dict):
        priors = PriorDict(priors)
    else:
        raise ValueError("Input priors not understood should be dict or PriorDict")

    priors.fill_priors(likelihood, default_priors_file=default_priors_file)

    # Generate the meta-data if not given and append the likelihood meta_data
    if meta_data is None:
        meta_data = dict()
    likelihood.label = label
    likelihood.outdir = outdir
    meta_data["likelihood"] = likelihood.meta_data
    meta_data["loaded_modules"] = loaded_modules_dict()
    meta_data["environment_packages"] = env_package_list(as_dataframe=True)

    if command_line_args.bilby_zero_likelihood_mode:
        from bilby.core.likelihood import ZeroLikelihood

        likelihood = ZeroLikelihood(likelihood)

    if isinstance(sampler, Sampler):
        pass
    elif isinstance(sampler, str) or inspect.isclass(sampler):
        # A name is resolved to its class; a class is used as given. The class
        # must be *called* to build an instance: invoking ``__init__`` on the
        # class itself would bind ``likelihood`` as ``self`` and return None.
        if isinstance(sampler, str):
            sampler_class = get_sampler_class(sampler)
        else:
            sampler_class = sampler
        sampler = sampler_class(
            likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            injection_parameters=injection_parameters,
            meta_data=meta_data,
            use_ratio=use_ratio,
            plot=plot,
            result_class=result_class,
            npool=npool,
            **kwargs,
        )
    else:
        raise ValueError(
            "Provided sampler should be a Sampler object or name of a known "
            f"sampler: {get_implemented_samplers()}."
        )

    if sampler.cached_result:
        logger.warning("Using cached result")
        result = sampler.cached_result
        fresh_run = False
    else:
        fresh_run = True
        # Run the sampler
        start_time = datetime.datetime.now()
        if command_line_args.bilby_test_mode:
            result = sampler._run_test()
        else:
            result = sampler.run_sampler()
        end_time = datetime.datetime.now()

        # Some samplers calculate the sampling time internally
        if result.sampling_time is None:
            result.sampling_time = end_time - start_time
        elif isinstance(result.sampling_time, (float, int)):
            # The first positional argument of timedelta is *days*; a numeric
            # sampling time is taken to be in seconds (it is stored in seconds
            # below), so it must be passed by keyword.
            result.sampling_time = datetime.timedelta(seconds=result.sampling_time)

        logger.info(f"Sampling time: {result.sampling_time}")
        # Convert sampling time into seconds
        result.sampling_time = result.sampling_time.total_seconds()

        result.log_noise_evidence = likelihood.noise_log_likelihood()
        if sampler.use_ratio:
            # The sampler worked with the likelihood ratio, so its evidence is
            # the Bayes factor; add the noise evidence to get the evidence.
            result.log_bayes_factor = result.log_evidence
            result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
        else:
            result.log_bayes_factor = result.log_evidence - result.log_noise_evidence

    # Convert the injection parameters exactly once, for fresh and cached
    # results alike, so a non-idempotent conversion is not applied twice.
    if conversion_function is not None and result.injection_parameters is not None:
        result.injection_parameters = conversion_function(result.injection_parameters)

    # Initial save of the sampler in case of failure in samples_to_posterior
    if fresh_run and save:
        result.save_to_file(extension=save, gzip=gzip, outdir=outdir)

    # Check if the posterior has already been created
    if getattr(result, "_posterior", None) is None:
        result.samples_to_posterior(
            likelihood=likelihood,
            priors=result.priors,
            conversion_function=conversion_function,
            npool=npool,
        )

    if save:
        # The overwrite here ensures we overwrite the initially stored data
        result.save_to_file(overwrite=True, extension=save, gzip=gzip, outdir=outdir)

    if plot:
        result.plot_corner()
    logger.info(f"Summary of results:\n{result}")
    return result
357def _check_marginalized_parameters_not_sampled(likelihood, priors):
358 for key in likelihood.marginalized_parameters:
359 if key in priors:
360 if not isinstance(priors[key], (float, DeltaFunction)):
361 raise SamplingMarginalisedParameterError(
362 f"Likelihood is {key} marginalized but you are trying to sample in {key}. "
363 )