Coverage for bilby/hyper/model.py: 73%
30 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1from ..core.utils import infer_args_from_function_except_n_args
class Model:
    r"""
    Population model that combines a set of factorizable models.

    This should take population parameters and return the probability.

    .. math::

        p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
    """

    def __init__(self, model_functions=None, cache=True):
        """
        Parameters
        ==========
        model_functions: list
            List of callables to compute the probability.
            If this includes classes, the :code:`__call__` method
            should return the probability.
            The required variables are chosen at run time based on either
            inspection or querying a :code:`variable_names` attribute.
            If :code:`None` (the default) the model starts with no
            component functions and :code:`prob` returns :code:`1.0`.
        cache: bool
            Whether to cache the value returned by the model functions,
            default=:code:`True`. The caching only looks at the parameters
            not the data, so should be used with caution. The caching also
            breaks :code:`jax` JIT compilation.
        """
        # The ``None`` default used to crash when building the cache
        # dictionaries below; treat it as "no component models".
        if model_functions is None:
            model_functions = list()
        self.models = model_functions
        self.cache = cache
        # Per-model cache: last parameter dict seen and the value it produced.
        self._cached_parameters = {model: None for model in self.models}
        self._cached_probability = {model: None for model in self.models}

        self.parameters = dict()

    def prob(self, data, **kwargs):
        """
        Compute the total population probability for the provided data given
        the keyword arguments.

        Parameters
        ==========
        data: dict
            Dictionary containing the points at which to evaluate the
            population model.
        kwargs: dict
            The population parameters. They override entries of
            :code:`self.parameters` for this call only and are not stored.
            These cannot include any of
            :code:`["dataset", "data", "self", "cls"]` unless the
            :code:`variable_names` attribute is available for the relevant
            model.

        Returns
        =======
        The product of the values returned by each model function, or
        :code:`1.0` if there are no model functions.
        """
        # Merge without mutating ``self.parameters`` so kwargs are per-call.
        parameters = {**self.parameters, **kwargs} if kwargs else self.parameters
        probability = 1.0
        for function in self.models:
            function_parameters = self._get_function_parameters(
                function, parameters
            )
            # ``.get`` so a model appended to ``self.models`` after
            # construction is evaluated instead of raising ``KeyError``.
            # NOTE(review): the ``==`` comparison assumes parameter values
            # compare to a single bool (e.g. scalars) — array-valued
            # parameters would need a different equality check.
            if (
                self.cache
                and self._cached_parameters.get(function) == function_parameters
            ):
                new_probability = self._cached_probability[function]
            else:
                new_probability = function(data, **function_parameters)
                if self.cache:
                    self._cached_parameters[function] = function_parameters
                    self._cached_probability[function] = new_probability
            probability *= new_probability
        return probability

    def _get_function_parameters(self, func, parameters=None):
        """
        Extract the subset of the population parameters that ``func`` needs.

        If the function is a class method we need to remove more arguments or
        have the variable names provided in the class.

        Parameters
        ==========
        func: callable
            A model function. If it has a :code:`variable_names` attribute
            those names are used as-is; otherwise its argument names are
            inferred and :code:`["dataset", "data", "self", "cls"]` are
            dropped.
        parameters: dict, optional
            Mapping to look the names up in; defaults to
            :code:`self.parameters`.

        Returns
        =======
        dict: each required name mapped to its value.

        Raises
        ======
        KeyError: if a required name is missing from ``parameters``.
        """
        if parameters is None:
            parameters = self.parameters
        if hasattr(func, "variable_names"):
            param_keys = func.variable_names
        else:
            ignore = {"dataset", "data", "self", "cls"}
            # Build a new list rather than deleting from the helper's return
            # value in place.
            param_keys = [
                key
                for key in infer_args_from_function_except_n_args(func, n=0)
                if key not in ignore
            ]
        return {key: parameters[key] for key in param_keys}