Coverage for bilby/core/prior/conditional.py: 96%
129 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1from .base import Prior, PriorException
2from .interpolated import Interped
3from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \
4 SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \
5 LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac
6from ..utils import infer_args_from_method, infer_parameters_from_function
def conditional_prior_factory(prior_class):
    """Build a conditional variant of ``prior_class``.

    The returned class subclasses ``prior_class`` and takes an additional
    ``condition_func`` argument. When ``rescale``, ``prob``, ``ln_prob`` or
    ``cdf`` is called with the required variables as keyword arguments,
    ``condition_func`` is evaluated first and every key/value pair it returns
    is set as an attribute on the instance, before delegating to
    ``prior_class``.

    Parameters
    ==========
    prior_class: type
        The prior class to wrap, e.g. ``Uniform`` or ``Gaussian``.

    Returns
    =======
    type: The generated ``ConditionalPrior`` subclass of ``prior_class``.
    """
    class ConditionalPrior(prior_class):
        def __init__(self, condition_func, name=None, latex_label=None, unit=None,
                     boundary=None, **reference_params):
            """Create a conditional prior.

            Parameters
            ==========
            condition_func: func
                Functional form of the condition for this prior. The first function argument
                has to be a dictionary for the `reference_params` (see below). The following
                arguments are the required variables that are required before we can draw this
                prior.
                It needs to return a dictionary with the modified values for the
                `reference_params` that are being used in the next draw.
                If None, an identity function that returns the reference parameters
                unchanged is used and no variables are required.
                For example if we have a Uniform prior for `x` depending on a different variable `y`
                `p(x|y)` with the boundaries linearly depending on y, then this
                could have the following form:

                .. code-block:: python

                    def condition_func(reference_params, y):
                        return dict(
                            minimum=reference_params['minimum'] + y,
                            maximum=reference_params['maximum'] + y
                        )

            name: str, optional
               See superclass
            latex_label: str, optional
                See superclass
            unit: str, optional
                See superclass
            boundary: str, optional
                See superclass. Only forwarded if the superclass accepts it.
            reference_params:
                Initial values for attributes such as `minimum`, `maximum`.
                This depends on the `prior_class`, for example for the Gaussian
                prior this is `mu` and `sigma`.
            """
            # Not every prior class accepts ``boundary``; only forward it when it does.
            if 'boundary' in infer_args_from_method(super(ConditionalPrior, self).__init__):
                super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
                                                       unit=unit, boundary=boundary, **reference_params)
            else:
                super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
                                                       unit=unit, **reference_params)

            self._required_variables = None
            # The setter also infers ``_required_variables`` from the function signature.
            self.condition_func = condition_func
            self._reference_params = reference_params
            if type(self) is ConditionalPrior:
                # Give the factory-generated class a meaningful name. This is
                # deliberately restricted to the generated class itself:
                # renaming ``self.__class__`` unconditionally would also rename
                # every subclass (e.g. ``ConditionalBasePrior`` or a user-defined
                # ``class MyPrior(ConditionalUniform)``) on instantiation, which
                # corrupts their ``__repr__`` output.
                ConditionalPrior.__name__ = 'Conditional{}'.format(prior_class.__name__)
                ConditionalPrior.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__)

        def sample(self, size=None, **required_variables):
            """Draw a sample from the prior

            Parameters
            ==========
            size: int or tuple of ints, optional
                See superclass
            required_variables:
                Any required variables that this prior depends on

            Returns
            =======
            float: See superclass

            """
            from ..utils.random import rng

            self.least_recently_sampled = self.rescale(rng.uniform(0, 1, size), **required_variables)
            return self.least_recently_sampled

        def rescale(self, val, **required_variables):
            """
            'Rescale' a sample from the unit line element to the prior.

            Parameters
            ==========
            val: Union[float, int, array_like]
                See superclass
            required_variables:
                Any required variables that this prior depends on

            Returns
            =======
            Union[float, array_like]: See superclass

            """
            self.update_conditions(**required_variables)
            return super(ConditionalPrior, self).rescale(val)

        def prob(self, val, **required_variables):
            """Return the prior probability of val.

            Parameters
            ==========
            val: Union[float, int, array_like]
                See superclass
            required_variables:
                Any required variables that this prior depends on


            Returns
            =======
            float: Prior probability of val
            """
            self.update_conditions(**required_variables)
            return super(ConditionalPrior, self).prob(val)

        def ln_prob(self, val, **required_variables):
            """Return the natural log prior probability of val.

            Parameters
            ==========
            val: Union[float, int, array_like]
                See superclass
            required_variables:
                Any required variables that this prior depends on


            Returns
            =======
            float: Natural log prior probability of val
            """
            self.update_conditions(**required_variables)
            return super(ConditionalPrior, self).ln_prob(val)

        def cdf(self, val, **required_variables):
            """Return the cdf of val.

            Parameters
            ==========
            val: Union[float, int, array_like]
                See superclass
            required_variables:
                Any required variables that this prior depends on


            Returns
            =======
            float: CDF of val
            """
            self.update_conditions(**required_variables)
            return super(ConditionalPrior, self).cdf(val)

        def update_conditions(self, **required_variables):
            """
            This method updates the conditional parameters (depending on the parent class
            this could be e.g. `minimum`, `maximum`, `mu`, `sigma`, etc.) of this prior
            class depending on the required variables it depends on.

            The condition function is applied to a copy of `self.reference_params`
            whenever the given variable names match `self.required_variables` exactly
            (for a prior without required variables this means every call, including
            one with no arguments). If no variables are given for a prior that does
            require some, the current conditional parameters are kept unchanged.

            Parameters
            ==========
            required_variables:
                Any required variables that this prior depends on.

            Raises
            ======
            IllegalRequiredVariablesException:
                If some, but not exactly the required, variables are given.
            """
            if sorted(required_variables) == sorted(self.required_variables):
                # Pass a copy so the condition function cannot mutate the reference values.
                parameters = self.condition_func(self.reference_params.copy(), **required_variables)
                for key, value in parameters.items():
                    setattr(self, key, value)
            elif len(required_variables) == 0:
                return
            else:
                raise IllegalRequiredVariablesException("Expected kwargs for {}. Got kwargs for {} instead."
                                                        .format(self.required_variables,
                                                                list(required_variables.keys())))

        @property
        def reference_params(self):
            """
            Initial values for attributes such as `minimum`, `maximum`.
            This depends on the `prior_class`, for example for the Gaussian
            prior this is `mu` and `sigma`. This is read-only.
            """
            return self._reference_params

        @property
        def condition_func(self):
            """The condition function applied by `update_conditions`."""
            return self._condition_func

        @condition_func.setter
        def condition_func(self, condition_func):
            if condition_func is None:
                # Identity condition: keep the reference parameters as they are.
                self._condition_func = lambda reference_params: reference_params
            else:
                self._condition_func = condition_func
            self._required_variables = infer_parameters_from_function(self.condition_func)

        @property
        def required_variables(self):
            """ The required variables to pass into the condition function. """
            return self._required_variables

        def get_instantiation_dict(self):
            """Return the superclass instantiation dict with the reference parameters.

            The reference parameters overwrite the superclass entries, so the
            original (unconditioned) values are reported rather than whatever
            the last condition update set on the instance.
            """
            instantiation_dict = super(ConditionalPrior, self).get_instantiation_dict()
            for key, value in self.reference_params.items():
                instantiation_dict[key] = value
            return instantiation_dict

        def reset_to_reference_parameters(self):
            """
            Reset the object attributes to match the original reference parameters
            """
            for key, value in self.reference_params.items():
                setattr(self, key, value)

        def __repr__(self):
            """Overrides the special method __repr__.

            Returns a representation of this instance that resembles how it is instantiated.
            Works correctly for all child classes

            Returns
            =======
            str: A string representation of this instance

            """
            prior_name = self.__class__.__name__
            instantiation_dict = self.get_instantiation_dict()
            # Represent the condition function by its dotted import path.
            instantiation_dict["condition_func"] = ".".join([
                instantiation_dict["condition_func"].__module__,
                instantiation_dict["condition_func"].__name__
            ])
            args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
                              for key in instantiation_dict])
            return "{}({})".format(prior_name, args)

    return ConditionalPrior
# Concrete conditional priors: each class below is the class returned by
# ``conditional_prior_factory`` for the corresponding prior, given a
# module-level name. They take a ``condition_func`` in addition to the
# arguments of the wrapped prior.


class ConditionalBasePrior(conditional_prior_factory(Prior)):
    """Conditional version of the base :class:`Prior` class."""
    pass


class ConditionalUniform(conditional_prior_factory(Uniform)):
    """Conditional version of :class:`Uniform`."""
    pass


class ConditionalDeltaFunction(conditional_prior_factory(DeltaFunction)):
    """Conditional version of :class:`DeltaFunction`."""
    pass


class ConditionalPowerLaw(conditional_prior_factory(PowerLaw)):
    """Conditional version of :class:`PowerLaw`."""
    pass


class ConditionalGaussian(conditional_prior_factory(Gaussian)):
    """Conditional version of :class:`Gaussian`."""
    pass


class ConditionalLogUniform(conditional_prior_factory(LogUniform)):
    """Conditional version of :class:`LogUniform`."""
    pass


class ConditionalSymmetricLogUniform(conditional_prior_factory(SymmetricLogUniform)):
    """Conditional version of :class:`SymmetricLogUniform`."""
    pass


class ConditionalCosine(conditional_prior_factory(Cosine)):
    """Conditional version of :class:`Cosine`."""
    pass


class ConditionalSine(conditional_prior_factory(Sine)):
    """Conditional version of :class:`Sine`."""
    pass


class ConditionalTruncatedGaussian(conditional_prior_factory(TruncatedGaussian)):
    """Conditional version of :class:`TruncatedGaussian`."""
    pass


class ConditionalHalfGaussian(conditional_prior_factory(HalfGaussian)):
    """Conditional version of :class:`HalfGaussian`."""
    pass


class ConditionalLogNormal(conditional_prior_factory(LogNormal)):
    """Conditional version of :class:`LogNormal`."""
    pass


class ConditionalExponential(conditional_prior_factory(Exponential)):
    """Conditional version of :class:`Exponential`."""
    pass


class ConditionalStudentT(conditional_prior_factory(StudentT)):
    """Conditional version of :class:`StudentT`."""
    pass


class ConditionalBeta(conditional_prior_factory(Beta)):
    """Conditional version of :class:`Beta`."""
    pass


class ConditionalLogistic(conditional_prior_factory(Logistic)):
    """Conditional version of :class:`Logistic`."""
    pass


class ConditionalCauchy(conditional_prior_factory(Cauchy)):
    """Conditional version of :class:`Cauchy`."""
    pass


class ConditionalGamma(conditional_prior_factory(Gamma)):
    """Conditional version of :class:`Gamma`."""
    pass


class ConditionalChiSquared(conditional_prior_factory(ChiSquared)):
    """Conditional version of :class:`ChiSquared`."""
    pass


class ConditionalFermiDirac(conditional_prior_factory(FermiDirac)):
    """Conditional version of :class:`FermiDirac`."""
    pass


class ConditionalInterped(conditional_prior_factory(Interped)):
    """Conditional version of :class:`Interped`."""
    pass
class DirichletElement(ConditionalBeta):
    r"""
    Single element in a dirichlet distribution

    The probability scales as

    .. math::
        p(x_n) \propto (x_\max - x_n)^{(N - n - 2)}

    for :math:`0 < x_n < x_\max`, where :math:`N` is ``n_dimensions``,
    :math:`n` is ``order`` and :math:`x_\max = 1 - \sum_{i < n} x_i`.
    This is a Beta distribution with ``alpha = 1`` and
    ``beta = n_dimensions - order - 1`` on ``[0, x_max]``. Note that
    ``order = n_dimensions - 1`` would give ``beta = 0``, so presumably
    only orders below ``n_dimensions - 1`` are meant to be used.

    Examples
    ========
    n_dimensions = 2:

    .. math::
        p(x_0) \propto 1 ; 0 < x_0 < 1

    n_dimensions = 3:

    .. math::
        p(x_0) &\propto (1 - x_0) ; 0 < x_0 < 1 \\
        p(x_1) &\propto 1 ; 0 < x_1 < 1 - x_0

    Parameters
    ==========
    order: int
        Order of this element of the dirichlet distribution.
    n_dimensions: int
        Total number of elements of the dirichlet distribution
    label: str
        Label for the dirichlet distribution.
        This should be the same for all elements.

    """

    def __init__(self, order, n_dimensions, label):
        """Create element ``label + str(order)`` of the dirichlet distribution.

        The element depends on the elements ``label + str(ii)`` for all
        ``ii < order``.
        """
        super(DirichletElement, self).__init__(
            minimum=0, maximum=1, alpha=1, beta=n_dimensions - order - 1,
            name=label + str(order),
            condition_func=self.dirichlet_condition
        )
        self.label = label
        self.n_dimensions = n_dimensions
        self.order = order
        # ``dirichlet_condition`` takes **kwargs, so the required variables are
        # set explicitly instead of being inferred from its signature.
        self._required_variables = [
            label + str(ii) for ii in range(order)
        ]
        # Restore the class name in case the parent ``__init__`` renamed
        # ``self.__class__``.
        self.__class__.__name__ = 'DirichletElement'
        self.__class__.__qualname__ = 'DirichletElement'

    def dirichlet_condition(self, reference_parms, **kwargs):
        """Return the bounds for this element given all lower-order elements.

        The upper bound is the probability mass left over by the
        lower-order elements, ``1 - sum(previous elements)``; the lower bound
        is taken from the reference parameters.
        """
        remaining = 1 - sum(kwargs[self.label + str(ii)] for ii in range(self.order))
        return dict(minimum=reference_parms["minimum"], maximum=remaining)

    def __repr__(self):
        # Bypass the conditional-prior representation (which adds
        # ``condition_func``); presumably this makes the repr follow this
        # class's own ``__init__`` arguments.
        return Prior.__repr__(self)

    def get_instantiation_dict(self):
        # Bypass the conditional-prior version, which adds the reference
        # parameters that this class's ``__init__`` does not accept.
        return Prior.get_instantiation_dict(self)
class ConditionalPriorException(PriorException):
    """Root of the exception hierarchy for errors raised by conditional priors."""
class IllegalRequiredVariablesException(ConditionalPriorException):
    """Raised when the variables passed to a conditional prior do not match its required variables."""