Coverage for bilby/hyper/likelihood.py: 77%
57 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
2import logging
4import numpy as np
6from ..core.likelihood import Likelihood
7from .model import Model
8from ..core.prior import PriorDict
class HyperparameterLikelihood(Likelihood):
    """ A likelihood for inferring hyperparameter posterior distributions

    See Eq. (34) of https://arxiv.org/abs/1809.02293 for a definition.

    For each single-event posterior, the population model evaluated at the
    samples is divided by the sampling prior at those samples and averaged
    over samples; the logs of these averages are summed over posteriors.

    Parameters
    ==========
    posteriors: list
        A list of pandas data frames, each holding one set of posterior
        samples. Each set may have a different size.
    hyper_prior: `bilby.hyper.model.Model`
        The population model, this can alternatively be a function.
    sampling_prior: `bilby.hyper.model.Model` or `bilby.core.prior.PriorDict`, optional
        The sampling prior, this can alternatively be a function. If not
        provided, the posteriors must carry a ``prior`` or ``log_prior``
        column giving the sampling prior evaluated at each sample.
    log_evidences: list, optional
        Log evidences for single runs to ensure proper normalisation
        of the hyperparameter likelihood. Their sum is returned by
        ``noise_log_likelihood``. If not provided, the noise log-likelihood
        is ``np.nan`` (undefined), so only ``log_likelihood_ratio`` is
        meaningful; it is a Bayes factor between the sampling prior and the
        hyperparameterised model.
    max_samples: int, optional
        Maximum number of samples to use from each set. The number actually
        used is the smaller of this and the length of the shortest set.

    Raises
    ======
    ValueError
        If ``sampling_prior`` is not given and the first posterior has
        neither a ``prior`` nor a ``log_prior`` column.
    """

    def __init__(self, posteriors, hyper_prior, sampling_prior=None,
                 log_evidences=None, max_samples=1e100):
        # Wrap bare callables so both priors expose the ``Model`` interface.
        if not isinstance(hyper_prior, Model):
            hyper_prior = Model([hyper_prior])
        if sampling_prior is None:
            if ('log_prior' not in posteriors[0].keys()) and ('prior' not in posteriors[0].keys()):
                raise ValueError('Missing both sampling prior function and prior or log_prior '
                                 'column in posterior dictionary. Must pass one or the other.')
        elif not isinstance(sampling_prior, (Model, PriorDict)):
            sampling_prior = Model([sampling_prior])
        if log_evidences is not None:
            self.evidence_factor = np.sum(log_evidences)
        else:
            # No single-run evidences: the noise log-likelihood is undefined.
            self.evidence_factor = np.nan
        self.posteriors = posteriors
        self.hyper_prior = hyper_prior
        self.sampling_prior = sampling_prior
        self.max_samples = max_samples
        super().__init__(hyper_prior.parameters)

        self.data = self.resample_posteriors()
        self.n_posteriors = len(self.posteriors)
        # Set explicitly here as well (resample_posteriors also sets them)
        # so subclasses overriding resample_posteriors still get them.
        self.samples_per_posterior = self.max_samples
        self.samples_factor =\
            - self.n_posteriors * np.log(self.samples_per_posterior)

    def log_likelihood_ratio(self):
        """Return the log hyper-likelihood ratio at ``self.parameters``.

        The population model is divided by the sampling prior at every
        sample, summed over samples, logged, and summed over posteriors.
        Adding ``samples_factor`` (``-n_posteriors * log(samples_per_posterior)``)
        turns each per-posterior sum into a sample mean, provided
        ``self.data`` holds ``samples_per_posterior`` samples per posterior.

        Returns
        =======
        float
            The log likelihood ratio, passed through ``np.nan_to_num`` so a
            NaN becomes 0 and infinities become large finite values.
        """
        # Push the sampler's current hyperparameters into the population model.
        self.hyper_prior.parameters.update(self.parameters)
        weights = self.hyper_prior.prob(self.data) / self.data['prior']
        log_l = np.sum(np.log(np.sum(weights, axis=-1)))
        log_l += self.samples_factor
        return np.nan_to_num(log_l)

    def noise_log_likelihood(self):
        """Return the summed single-run log evidences (``np.nan`` if none were given)."""
        return self.evidence_factor

    def log_likelihood(self):
        """Return ``noise_log_likelihood() + log_likelihood_ratio()``."""
        return self.noise_log_likelihood() + self.log_likelihood_ratio()

    def resample_posteriors(self, max_samples=None):
        """
        Convert list of pandas DataFrame object to dict of arrays.

        Each posterior is downsampled with ``DataFrame.sample`` to a common
        number of samples. A ``prior`` entry holding the sampling prior at
        each sample is attached: it comes from ``self.sampling_prior`` when
        one is set, otherwise from ``exp(log_prior)`` or an existing
        ``prior`` column.

        As side effects, ``self.max_samples``, ``self.samples_per_posterior``
        and ``self.samples_factor`` are updated to the sample count used, so
        the normalisation in ``log_likelihood_ratio`` matches the returned data.

        Parameters
        ==========
        max_samples: int, opt
            Maximum number of samples to take from each posterior. If not
            given, the current ``self.max_samples`` is used. In either case
            the value is capped at the length of the shortest posterior.

        Returns
        =======
        data: dict
            Dictionary containing arrays of size (n_posteriors, max_samples).
            There is a key for each column of the first posterior (except
            ``log_prior``), plus a ``prior`` key.
        """
        if max_samples is not None:
            self.max_samples = max_samples
        # DataFrame.sample needs an integer count; a float such as 500.0
        # would otherwise be passed through when it is the minimum.
        self.max_samples = int(min(
            [self.max_samples] + [len(posterior) for posterior in self.posteriors]))
        # Keep the -N log(S) normalisation consistent with the sample count,
        # so a later call with a new max_samples does not leave it stale.
        self.samples_per_posterior = self.max_samples
        self.samples_factor = \
            - len(self.posteriors) * np.log(self.samples_per_posterior)

        data = {key: [] for key in self.posteriors[0]}
        # log_prior is replaced by the linear-space 'prior' column below.
        data.pop('log_prior', None)
        data.setdefault('prior', [])
        logging.debug('Downsampling to %s samples per posterior.', self.max_samples)
        for posterior in self.posteriors:
            temp = posterior.sample(self.max_samples)
            if self.sampling_prior is not None:
                temp['prior'] = self.sampling_prior.prob(temp, axis=0)
            elif 'log_prior' in temp.keys():
                temp['prior'] = np.exp(temp['log_prior'])
            for key in data:
                data[key].append(temp[key])
        return {key: np.array(values) for key, values in data.items()}