Coverage for bilby/hyper/likelihood.py: 77%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1 

2import logging 

3 

4import numpy as np 

5 

6from ..core.likelihood import Likelihood 

7from .model import Model 

8from ..core.prior import PriorDict 

9 

10 

class HyperparameterLikelihood(Likelihood):
    """
    A likelihood for inferring hyperparameter posterior distributions.

    See Eq. (34) of https://arxiv.org/abs/1809.02293 for a definition.

    Parameters
    ==========
    posteriors: list
        A list of pandas DataFrames, each containing one set of posterior
        samples. The sets may have different sizes.
    hyper_prior: `bilby.hyper.model.Model`
        The population model; a plain function is also accepted and will
        be wrapped in a `Model`.
    sampling_prior: `bilby.hyper.model.Model`, optional
        The prior the samples were drawn under; a plain function is also
        accepted. If omitted, every posterior must instead carry a
        `prior` or `log_prior` column.
    log_evidences: list, optional
        Log evidences of the single-event runs, used to normalise the
        hyperparameter likelihood. If not provided they are treated as 0,
        which makes the likelihood a Bayes factor between the sampling
        prior and the hyperparameterised model.
    max_samples: int, optional
        Maximum number of samples to use from each posterior set.
    """

    def __init__(self, posteriors, hyper_prior, sampling_prior=None,
                 log_evidences=None, max_samples=1e100):
        # Wrap a bare population function so that `.prob` is available.
        if not isinstance(hyper_prior, Model):
            hyper_prior = Model([hyper_prior])
        if sampling_prior is None:
            # With no explicit sampling prior, each posterior must supply
            # it as a column, either directly or in log form.
            has_prior_column = (
                'prior' in posteriors[0].keys()
                or 'log_prior' in posteriors[0].keys()
            )
            if not has_prior_column:
                raise ValueError(
                    'Missing both sampling prior function and prior or log_prior '
                    'column in posterior dictionary. Must pass one or the other.')
        elif not isinstance(sampling_prior, (Model, PriorDict)):
            sampling_prior = Model([sampling_prior])
        # NaN marks "no evidences supplied"; noise_log_likelihood returns it.
        self.evidence_factor = (
            np.nan if log_evidences is None else np.sum(log_evidences))
        self.posteriors = posteriors
        self.hyper_prior = hyper_prior
        self.sampling_prior = sampling_prior
        self.max_samples = max_samples
        super().__init__(hyper_prior.parameters)

        # resample_posteriors may shrink self.max_samples to the shortest chain.
        self.data = self.resample_posteriors()
        self.n_posteriors = len(self.posteriors)
        self.samples_per_posterior = self.max_samples
        # Normalisation of the Monte Carlo average over posterior samples.
        self.samples_factor = (
            -self.n_posteriors * np.log(self.samples_per_posterior))

    def log_likelihood_ratio(self):
        """Return the log-likelihood ratio for the current hyperparameters."""
        # Push the current hyperparameter values into the population model.
        self.hyper_prior.parameters.update(self.parameters)
        # Importance weights: population probability over sampling prior.
        weights = self.hyper_prior.prob(self.data) / self.data['prior']
        log_l = np.sum(np.log(np.sum(weights, axis=-1)))
        log_l += self.samples_factor
        # Guard against nan/inf, e.g. from events with zero support.
        return np.nan_to_num(log_l)

    def noise_log_likelihood(self):
        """Return the summed single-run log evidences (NaN if none given)."""
        return self.evidence_factor

    def log_likelihood(self):
        """Return the full log-likelihood: noise evidence plus ratio."""
        return self.noise_log_likelihood() + self.log_likelihood_ratio()

    def resample_posteriors(self, max_samples=None):
        """
        Convert the list of posterior DataFrames into a dict of arrays.

        Parameters
        ==========
        max_samples: int, optional
            Maximum number of samples to take from each posterior;
            defaults to the length of the shortest posterior chain.

        Returns
        =======
        data: dict
            Maps each key of the first posterior (plus 'prior', minus
            'log_prior') to an array of shape (n_posteriors, max_samples).
        """
        if max_samples is not None:
            self.max_samples = max_samples
        # Every set contributes equally many samples, so cap at the
        # shortest chain.
        self.max_samples = min(
            [self.max_samples] + [len(frame) for frame in self.posteriors])
        columns = {key: [] for key in self.posteriors[0]}
        # 'log_prior' is replaced by an exponentiated 'prior' column below.
        columns.pop('log_prior', None)
        columns.setdefault('prior', [])
        logging.debug('Downsampling to {} samples per posterior.'.format(
            self.max_samples))
        for frame in self.posteriors:
            drawn = frame.sample(self.max_samples)
            if self.sampling_prior is not None:
                drawn['prior'] = self.sampling_prior.prob(drawn, axis=0)
            elif 'log_prior' in drawn:
                drawn['prior'] = np.exp(drawn['log_prior'])
            for key in columns:
                columns[key].append(drawn[key])
        return {key: np.array(values) for key, values in columns.items()}