Coverage for bilby/hyper/model.py: 73%

30 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1from ..core.utils import infer_args_from_function_except_n_args 

2 

3 

class Model:
    r"""
    Population model that combines a set of factorizable models.

    This should take population parameters and return the probability.

    .. math::

        p(\theta | \Lambda) = \prod_{i} p_{i}(\theta | \Lambda)
    """

    def __init__(self, model_functions=None, cache=True):
        """
        Parameters
        ==========
        model_functions: list
            List of callables to compute the probability.
            If this includes classes, the :code:`__call__` method
            should return the probability.
            The required variables are chosen at run time based on either
            inspection or querying a :code:`variable_names` attribute.
            If :code:`None` (the default), the model starts with no
            functions and :code:`prob` returns :code:`1.0`.
        cache: bool
            Whether to cache the value returned by the model functions,
            default=:code:`True`. The caching only looks at the parameters
            not the data, so should be used with caution. The caching also
            breaks :code:`jax` JIT compilation.
        """
        if model_functions is None:
            # The cache dictionaries below iterate over the models, so a
            # missing list must become an empty one rather than None.
            model_functions = list()
        self.models = model_functions
        self.cache = cache
        # Per-function cache of the last parameters used and the value
        # returned for them; None means "nothing cached yet".
        self._cached_parameters = {model: None for model in self.models}
        self._cached_probability = {model: None for model in self.models}

        # Current population parameters, looked up by name for each function.
        self.parameters = dict()

    def prob(self, data, **kwargs):
        """
        Compute the total population probability for the provided data given
        the keyword arguments.

        Parameters
        ==========
        data: dict
            Dictionary containing the points at which to evaluate the
            population model.
        kwargs: dict
            The population parameters. These cannot include any of
            :code:`["dataset", "data", "self", "cls"]` unless the
            :code:`variable_names` attribute is available for the relevant
            model. Any values passed here are stored in
            :code:`self.parameters`, overriding existing entries, before
            the model functions are evaluated.

        Returns
        =======
        The product of the values returned by each model function
        (:code:`1.0` if there are no model functions).
        """
        # Honour explicitly passed population parameters; without this they
        # would be silently discarded despite being documented above.
        self.parameters.update(kwargs)
        probability = 1.0
        for function in self.models:
            function_parameters = self._get_function_parameters(function)
            # .get() so a function appended to self.models after
            # construction simply has an empty cache instead of raising.
            if (
                self.cache
                and self._cached_parameters.get(function) == function_parameters
            ):
                new_probability = self._cached_probability[function]
            else:
                new_probability = function(data, **function_parameters)
                if self.cache:
                    self._cached_parameters[function] = function_parameters
                    self._cached_probability[function] = new_probability
            probability *= new_probability
        return probability

    def _get_function_parameters(self, func):
        """
        Collect the current values of the parameters required by a function.

        If the function has a :code:`variable_names` attribute, that is used
        as the list of required names. Otherwise the names are inferred from
        the function signature, and :code:`["dataset", "data", "self",
        "cls"]` are dropped, since those are the data argument or the
        implicit first argument of a class method rather than population
        parameters.

        Parameters
        ==========
        func: callable
            One of the model functions.

        Returns
        =======
        dict
            Mapping from each required parameter name to its value in
            :code:`self.parameters`.

        Raises
        ======
        KeyError
            If a required parameter is not present in :code:`self.parameters`.
        """
        if hasattr(func, "variable_names"):
            param_keys = func.variable_names
        else:
            ignore = ["dataset", "data", "self", "cls"]
            param_keys = [
                key
                for key in infer_args_from_function_except_n_args(func, n=0)
                if key not in ignore
            ]
        return {key: self.parameters[key] for key in param_keys}