Coverage for bilby/core/prior/base.py: 89%
214 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1from importlib import import_module
2import json
3import os
4import re
6import numpy as np
7import scipy.stats
8from scipy.interpolate import interp1d
10from ..utils import (
11 infer_args_from_method,
12 BilbyJsonEncoder,
13 decode_bilby_json,
14 logger,
15 get_dict_with_properties,
16)
class Prior(object):
    """Base class for bilby priors.

    Subclasses are expected to override :meth:`rescale` and :meth:`prob`;
    the base implementations return ``None`` and ``np.nan`` respectively.
    """
    _default_latex_labels = {}

    def __init__(self, name=None, latex_label=None, unit=None, minimum=-np.inf,
                 maximum=np.inf, check_range_nonzero=True, boundary=None):
        """ Implements a Prior object

        Parameters
        ==========
        name: str, optional
            Name associated with prior.
        latex_label: str, optional
            Latex label associated with prior, used for plotting. If None,
            the label is looked up in ``_default_latex_labels`` by ``name``,
            falling back to ``name`` itself.
        unit: str, optional
            If given, a Latex string describing the units of the parameter.
        minimum: float, optional
            Minimum of the domain, default=-np.inf
        maximum: float, optional
            Maximum of the domain, default=np.inf
        check_range_nonzero: boolean, optional
            If True, checks that the prior range is non-zero
        boundary: str, optional
            The boundary condition of the prior, can be 'periodic', 'reflective'
            Currently implemented in cpnest, dynesty and pymultinest.

        Raises
        ======
        ValueError
            If ``check_range_nonzero`` is True and ``maximum <= minimum``, or
            if ``boundary`` is not 'periodic', 'reflective' or None.
        """
        if check_range_nonzero and maximum <= minimum:
            raise ValueError(
                "maximum {} <= minimum {} for {} prior on {}".format(
                    maximum, minimum, type(self).__name__, name
                )
            )
        self.name = name
        self.latex_label = latex_label
        self.unit = unit
        self.minimum = minimum
        self.maximum = maximum
        self.check_range_nonzero = check_range_nonzero
        self.least_recently_sampled = None
        self.boundary = boundary
        self._is_fixed = False

    def __call__(self):
        """Overrides the __call__ special method. Calls the sample method.

        Returns
        =======
        float: The return value of the sample method.
        """
        return self.sample()

    def __eq__(self, other):
        """
        Test equality of two prior objects.

        Returns true iff:

        - The class of the two priors are the same
        - Both priors have the same keys in the __dict__ attribute
        - The instantiation arguments match

        We don't check that all entries of the __dict__ attribute
        are equal as some attributes are variable for conditional
        priors.

        Parameters
        ==========
        other: Prior
            The prior to compare with

        Returns
        =======
        bool
            Whether the priors are equivalent

        Notes
        =====
        A special case is made for :code:`scipy.stats.beta` instances: any
        instantiation argument that is a frozen scipy distribution is skipped.
        It may be possible to remove this as we now only check instantiation
        arguments.
        """
        if self.__class__ != other.__class__:
            return False
        if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()):
            return False
        # Type of frozen scipy distributions; built once instead of per key.
        frozen_distribution_type = type(scipy.stats.beta(1., 1.))
        this_dict = self.get_instantiation_dict()
        other_dict = other.get_instantiation_dict()
        for key in this_dict:
            if key == "least_recently_sampled":
                continue
            if isinstance(this_dict[key], np.ndarray):
                if not np.array_equal(this_dict[key], other_dict[key]):
                    return False
            elif isinstance(this_dict[key], frozen_distribution_type):
                continue
            else:
                if not this_dict[key] == other_dict[key]:
                    return False
        return True

    def sample(self, size=None):
        """Draw a sample from the prior

        Parameters
        ==========
        size: int or tuple of ints, optional
            See numpy.random.uniform docs

        Returns
        =======
        float: A random number between 0 and 1, rescaled to match the distribution of this Prior
        """
        from ..utils.random import rng

        self.least_recently_sampled = self.rescale(rng.uniform(0, 1, size))
        return self.least_recently_sampled

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the prior.

        This should be overwritten by each subclass.

        Parameters
        ==========
        val: Union[float, int, array_like]
            A random number between 0 and 1

        Returns
        =======
        None
        """
        return None

    def prob(self, val):
        """Return the prior probability of val, this should be overwritten

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        np.nan
        """
        return np.nan

    def cdf(self, val):
        """Generic method to calculate the CDF, can be overwritten in subclass.

        The CDF is estimated by cumulative trapezoidal integration of
        :meth:`prob` on 1000 evenly spaced points between ``minimum`` and
        ``maximum`` and linearly interpolated; values below the grid map to 0
        and values above it map to 1.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        array_like: The interpolated CDF evaluated at val.

        Raises
        ======
        ValueError
            If either bound of the prior is infinite.
        """
        from scipy.integrate import cumulative_trapezoid
        if np.any(np.isinf([self.minimum, self.maximum])):
            raise ValueError(
                "Unable to use the generic CDF calculation for priors with "
                "infinite support")
        x = np.linspace(self.minimum, self.maximum, 1000)
        pdf = self.prob(x)
        cdf = cumulative_trapezoid(pdf, x, initial=0)
        interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False,
                          fill_value=(0, 1))
        return interp(val)

    def ln_prob(self, val):
        """Return the natural log of :meth:`prob` evaluated at val.

        Subclasses may override this with a more accurate implementation.
        Division-by-zero warnings from ``np.log(0)`` are suppressed, so zero
        probability gives ``-inf``; with the base :meth:`prob` the result is
        ``np.nan``.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[float, array_like]: ``np.log(self.prob(val))``
        """
        with np.errstate(divide='ignore'):
            return np.log(self.prob(val))

    def is_in_prior_range(self, val):
        """Check whether val lies within ``[minimum, maximum]`` (inclusive).

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[bool, array_like]: True where ``minimum <= val <= maximum``,
        False otherwise (element-wise for arrays).
        """
        return (val >= self.minimum) & (val <= self.maximum)

    def __repr__(self):
        """Overrides the special method __repr__.

        Returns a representation of this instance that resembles how it is instantiated.
        Works correctly for all child classes

        Returns
        =======
        str: A string representation of this instance
        """
        prior_name = self.__class__.__name__
        prior_module = self.__class__.__module__
        instantiation_dict = self.get_instantiation_dict()
        args = ', '.join([f'{key}={repr(instantiation_dict[key])}' for key in instantiation_dict])
        if "bilby.core.prior" in prior_module:
            return f"{prior_name}({args})"
        else:
            return f"{prior_module}.{prior_name}({args})"

    @property
    def is_fixed(self):
        """
        Returns True if the prior is fixed and should not be used in the sampler.

        The value is read from ``self._is_fixed``, which is False by default
        and may be set to True by subclasses.

        Returns
        =======
        bool: Whether it's fixed or not!
        """
        return self._is_fixed

    @property
    def latex_label(self):
        """Latex label that can be used for plots.

        Draws from a set of default labels if no label is given

        Returns
        =======
        str: A latex representation for this prior
        """
        return self.__latex_label

    @latex_label.setter
    def latex_label(self, latex_label=None):
        if latex_label is None:
            self.__latex_label = self.__default_latex_label
        else:
            self.__latex_label = latex_label

    @property
    def unit(self):
        return self.__unit

    @unit.setter
    def unit(self, unit):
        self.__unit = unit

    @property
    def latex_label_with_unit(self):
        """ If a unit is specified, returns a string of the latex label and unit """
        if self.unit is not None:
            return "{} [{}]".format(self.latex_label, self.unit)
        else:
            return self.latex_label

    @property
    def minimum(self):
        return self._minimum

    @minimum.setter
    def minimum(self, minimum):
        self._minimum = minimum

    @property
    def maximum(self):
        return self._maximum

    @maximum.setter
    def maximum(self, maximum):
        self._maximum = maximum

    @property
    def width(self):
        """Size of the domain, ``maximum - minimum``."""
        return self.maximum - self.minimum

    def get_instantiation_dict(self):
        """Return the current values of this prior's ``__init__`` arguments.

        Returns
        =======
        dict: Mapping of each ``__init__`` argument name to its current value.
        """
        subclass_args = infer_args_from_method(self.__init__)
        dict_with_properties = get_dict_with_properties(self)
        return {key: dict_with_properties[key] for key in subclass_args}

    @property
    def boundary(self):
        return self._boundary

    @boundary.setter
    def boundary(self, boundary):
        if boundary not in ['periodic', 'reflective', None]:
            raise ValueError('{} is not a valid setting for prior boundaries'.format(boundary))
        self._boundary = boundary

    @property
    def __default_latex_label(self):
        # Look the name up in the class-level defaults, falling back to the name.
        if self.name in self._default_latex_labels.keys():
            label = self._default_latex_labels[self.name]
        else:
            label = self.name
        return label

    def to_json(self):
        """Serialize this prior to a JSON string using BilbyJsonEncoder."""
        return json.dumps(self, cls=BilbyJsonEncoder)

    @classmethod
    def from_json(cls, dct):
        """Rebuild an object from a decoded bilby JSON dictionary."""
        return decode_bilby_json(dct)

    @classmethod
    def from_repr(cls, string):
        """Generate the prior from its __repr__"""
        return cls._from_repr(string)

    @classmethod
    def _from_repr(cls, string):
        """Instantiate ``cls`` from the argument list of its repr string.

        Raises
        ======
        AttributeError
            If an argument is not accepted by ``cls.__init__`` and the class
            has no ``reference_params`` attribute.
        """
        subclass_args = infer_args_from_method(cls.__init__)

        string = string.replace(' ', '')
        kwargs = cls._split_repr(string)
        for key in kwargs:
            val = kwargs[key]
            if key not in subclass_args and not hasattr(cls, "reference_params"):
                raise AttributeError('Unknown argument {} for class {}'.format(
                    key, cls.__name__))
            else:
                kwargs[key] = cls._parse_argument_string(val)
            # Function-valued arguments given by name are resolved to callables;
            # bare names are looked up in this module.
            if key in ["condition_func", "conversion_function"] and isinstance(kwargs[key], str):
                if "." in kwargs[key]:
                    module = '.'.join(kwargs[key].split('.')[:-1])
                    name = kwargs[key].split('.')[-1]
                else:
                    module = __name__
                    name = kwargs[key]
                kwargs[key] = getattr(import_module(module), name)
        return cls(**kwargs)

    @staticmethod
    def _split_top_level(string):
        """Split ``string`` on commas that are not nested in brackets or quotes.

        Commas inside ``()``, ``[]``, ``{}`` or inside single/double quoted
        strings (honouring backslash escapes) are kept as part of the piece.
        Like ``str.split(',')``, an empty string yields ``['']``.

        Parameters
        ==========
        string: str

        Returns
        =======
        list of str
        """
        openers = {'(', '[', '{'}
        closers = {')', ']', '}'}
        pieces = []
        current = []
        depth = 0
        quote = None
        escaped = False
        for char in string:
            if quote is not None:
                if escaped:
                    escaped = False
                elif char == '\\':
                    escaped = True
                elif char == quote:
                    quote = None
            elif char in ("'", '"'):
                quote = char
            elif char in openers:
                depth += 1
            elif char in closers:
                # Clamp so stray closers cannot make depth negative.
                depth = max(depth - 1, 0)
            elif char == ',' and depth == 0:
                pieces.append(''.join(current))
                current = []
                continue
            current.append(char)
        pieces.append(''.join(current))
        return pieces

    @classmethod
    def _split_repr(cls, string):
        """Split a repr argument string into a ``{name: value_string}`` dict.

        Arguments are separated at top-level commas only, so nested prior
        calls, lists, dicts and quoted strings containing commas stay intact.
        An argument counts as a keyword only when the text before its first
        ``=`` is a Python identifier. Any other argument is treated as
        positional and named after the matching ``cls.__init__`` argument.

        Parameters
        ==========
        string: str
            The text between the outer parentheses of a prior repr.

        Returns
        =======
        dict: Argument name to (unparsed) value string.
        """
        # Only needed for positional arguments, so computed lazily.
        subclass_args = None
        kwargs = dict()
        for ii, arg in enumerate(cls._split_top_level(string)):
            key, separator, value = arg.partition('=')
            if separator and key.isidentifier():
                kwargs[key] = value
            else:
                logger.debug(
                    'Reading priors with non-keyword arguments is dangerous!')
                if subclass_args is None:
                    subclass_args = infer_args_from_method(cls.__init__)
                kwargs[subclass_args[ii]] = arg
        return kwargs

    @classmethod
    def _parse_argument_string(cls, val):
        """
        Parse a string into the appropriate type for prior reading.

        The following tests are applied in order:

        - If the string is 'None':
            `None` is returned.
        - Else If the string is a raw or unicode string, e.g., r'foo':
            A stripped version of the string is returned, e.g., foo.
        - Else If the string starts and ends with ', e.g., 'foo':
            A stripped version of the string is returned, e.g., foo.
        - Else If the string contains an open parenthesis, (, and does not
          start with [ or {:
            The string is interpreted as a call to instantiate another prior
            class, Bilby will attempt to recursively construct that prior,
            e.g., Uniform(minimum=0, maximum=1), my.custom.PriorClass(**kwargs).
        - Else:
            Try to evaluate the string using `eval` with `np`, `inf` and `pi`
            available (e.g., np.pi / 2, 1.57, [1, 2]). If that raises a
            NameError and the string contains a ".", it is treated as a path
            to a Python object and imported, e.g. "some_module.some_function"
            returns :code:`import some_module; return some_module.some_function`.
            If it raises a NameError and contains no ".", the string is
            returned unchanged.

        Parameters
        ==========
        val: str
            The string version of the argument

        Returns
        =======
        val: object
            The parsed version of the argument.

        Raises
        ======
        TypeError:
            If a dotted path's final attribute cannot be found in its module.
        ImportError:
            If the module part of a dotted path cannot be imported.
        """
        if val == 'None':
            val = None
        elif re.sub(r'\'.*\'', '', val) in ['r', 'u']:
            val = val[2:-1]
        elif val.startswith("'") and val.endswith("'"):
            val = val.strip("'")
        elif '(' in val and not val.startswith(("[", "{")):
            other_cls = val.split('(')[0]
            vals = '('.join(val.split('(')[1:])[:-1]
            if "." in other_cls:
                module = '.'.join(other_cls.split('.')[:-1])
                other_cls = other_cls.split('.')[-1]
            else:
                # The package containing this module (e.g. bilby.core.prior).
                module = __name__.replace('.' + os.path.basename(__file__).replace('.py', ''), '')
            other_cls = getattr(import_module(module), other_cls)
            val = other_cls.from_repr(vals)
        else:
            # NOTE(review): eval inserts __builtins__ into the empty globals
            # dict, so arbitrary code can run here. Only parse strings from
            # trusted sources.
            try:
                val = eval(val, dict(), dict(np=np, inf=np.inf, pi=np.pi))
            except NameError:
                if "." in val:
                    module = '.'.join(val.split('.')[:-1])
                    func = val.split('.')[-1]
                    new_val = getattr(import_module(module), func, val)
                    if val == new_val:
                        raise TypeError(
                            "Cannot evaluate prior, "
                            f"failed to parse argument {val}"
                        )
                    else:
                        val = new_val
        return val
class Constraint(Prior):
    """Prior used purely to restrict a parameter to an open interval.

    Instances are marked as fixed (``is_fixed`` is True), and ``prob`` is a
    boolean indicator of strict containment in ``(minimum, maximum)``.
    """

    def __init__(self, minimum, maximum, name=None, latex_label=None,
                 unit=None):
        super().__init__(
            name=name,
            latex_label=latex_label,
            unit=unit,
            minimum=minimum,
            maximum=maximum,
        )
        # Constraints are flagged as fixed rather than sampled.
        self._is_fixed = True

    def prob(self, val):
        """Return True where ``minimum < val < maximum`` (strictly), else False."""
        above_minimum = val > self.minimum
        below_maximum = val < self.maximum
        return above_minimum & below_maximum
class PriorException(Exception):
    """Root exception type for errors raised while handling priors."""