Coverage for bilby/core/prior/conditional.py: 96%

129 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1from .base import Prior, PriorException 

2from .interpolated import Interped 

3from .analytical import DeltaFunction, PowerLaw, Uniform, LogUniform, \ 

4 SymmetricLogUniform, Cosine, Sine, Gaussian, TruncatedGaussian, HalfGaussian, \ 

5 LogNormal, Exponential, StudentT, Beta, Logistic, Cauchy, Gamma, ChiSquared, FermiDirac 

6from ..utils import infer_args_from_method, infer_parameters_from_function 

7 

8 

def conditional_prior_factory(prior_class):
    """Build a conditional variant of `prior_class`.

    The returned class subclasses `prior_class`. Its `sample`, `rescale`,
    `prob`, `ln_prob` and `cdf` methods accept the values of other variables
    as keyword arguments and pass them to a user supplied `condition_func`
    to update the attributes of the prior before delegating to `prior_class`.

    Parameters
    ==========
    prior_class: class
        The prior class the conditional prior is derived from.

    Returns
    =======
    class: The `ConditionalPrior` subclass of `prior_class`
    """
    class ConditionalPrior(prior_class):
        def __init__(self, condition_func, name=None, latex_label=None, unit=None,
                     boundary=None, **reference_params):
            """

            Parameters
            ==========
            condition_func: func
                Function describing how this prior depends on other variables.
                It is called with a copy of the `reference_params` dictionary as
                first argument, followed by the required variables as keyword
                arguments. It has to return a dictionary; every key of that
                dictionary is set as an attribute of this prior.
                For example, a Uniform prior `p(x|y)` whose boundaries are
                shifted linearly by `y` could use:

                .. code-block:: python

                    def condition_func(reference_params, y):
                        return dict(
                            minimum=reference_params['minimum'] + y,
                            maximum=reference_params['maximum'] + y
                        )

                If `None` is given, a function returning the unchanged
                `reference_params` is used.
            name: str, optional
                See superclass
            latex_label: str, optional
                See superclass
            unit: str, optional
                See superclass
            boundary: str, optional
                See superclass. Only forwarded if the superclass constructor
                accepts a `boundary` argument.
            reference_params:
                Initial values for attributes such as `minimum`, `maximum`.
                Which ones are needed depends on `prior_class`, for example
                the Gaussian prior takes `mu` and `sigma`.
            """
            parent_init = super().__init__
            init_kwargs = dict(name=name, latex_label=latex_label, unit=unit)
            # Not every parent prior takes a `boundary` argument.
            if 'boundary' in infer_args_from_method(parent_init):
                init_kwargs['boundary'] = boundary
            parent_init(**init_kwargs, **reference_params)

            self._reference_params = reference_params
            self._required_variables = None
            # The setter also fills in `_required_variables`.
            self.condition_func = condition_func

            # Rename the class of this instance after the wrapped prior.
            cls = self.__class__
            cls.__name__ = 'Conditional{}'.format(prior_class.__name__)
            cls.__qualname__ = 'Conditional{}'.format(prior_class.__qualname__)

        def sample(self, size=None, **required_variables):
            """Draw a sample from the prior

            Uniform draws on the unit interval are pushed through `rescale`
            and the result is stored in `least_recently_sampled`.

            Parameters
            ==========
            size: int or tuple of ints, optional
                See superclass
            required_variables:
                Any required variables that this prior depends on

            Returns
            =======
            float: See superclass

            """
            from ..utils.random import rng

            unit_draw = rng.uniform(0, 1, size)
            self.least_recently_sampled = self.rescale(unit_draw, **required_variables)
            return self.least_recently_sampled

        def rescale(self, val, **required_variables):
            """
            'Rescale' a sample from the unit line element to the prior.

            The conditions are updated first, then the superclass is used.

            Parameters
            ==========
            val: Union[float, int, array_like]
                See superclass
            required_variables:
                Any required variables that this prior depends on

            """
            self.update_conditions(**required_variables)
            return super().rescale(val)

        def prob(self, val, **required_variables):
            """Return the prior probability of val.

            The conditions are updated first, then the superclass is used.

            Parameters
            ==========
            val: Union[float, int, array_like]
                See superclass
            required_variables:
                Any required variables that this prior depends on

            Returns
            =======
            float: Prior probability of val
            """
            self.update_conditions(**required_variables)
            return super().prob(val)

        def ln_prob(self, val, **required_variables):
            """Return the natural log prior probability of val.

            The conditions are updated first, then the superclass is used.

            Parameters
            ==========
            val: Union[float, int, array_like]
                See superclass
            required_variables:
                Any required variables that this prior depends on

            Returns
            =======
            float: Natural log prior probability of val
            """
            self.update_conditions(**required_variables)
            return super().ln_prob(val)

        def cdf(self, val, **required_variables):
            """Return the cdf of val.

            The conditions are updated first, then the superclass is used.

            Parameters
            ==========
            val: Union[float, int, array_like]
                See superclass
            required_variables:
                Any required variables that this prior depends on

            Returns
            =======
            float: CDF of val
            """
            self.update_conditions(**required_variables)
            return super().cdf(val)

        def update_conditions(self, **required_variables):
            """
            Update the conditional parameters of this prior (depending on the
            parent class this could be e.g. `minimum`, `maximum`, `mu`, `sigma`).

            If the names of the given keyword arguments match
            `required_variables`, `condition_func` is called with a copy of
            `reference_params` and every entry of the returned dictionary is set
            as an attribute. If they do not match and no keyword arguments were
            given, the attributes are left as they are.

            Parameters
            ==========
            required_variables:
                Any required variables that this prior depends on.

            Raises
            ======
            IllegalRequiredVariablesException:
                If keyword arguments are given whose names do not match
                `required_variables`.

            """
            given_names = sorted(required_variables)
            if given_names == sorted(self.required_variables):
                updated = self.condition_func(self.reference_params.copy(), **required_variables)
                for attribute, new_value in updated.items():
                    setattr(self, attribute, new_value)
                return
            if not required_variables:
                # Nothing passed: keep the most recently used parameters.
                return
            raise IllegalRequiredVariablesException("Expected kwargs for {}. Got kwargs for {} instead."
                                                    .format(self.required_variables,
                                                            list(required_variables.keys())))

        @property
        def reference_params(self):
            """
            Initial values for attributes such as `minimum`, `maximum`, as
            passed to the constructor. Which keys are present depends on
            `prior_class`. No setter is defined for this property.
            """
            return self._reference_params

        @property
        def condition_func(self):
            """ The function used to update the parameters of this prior. """
            return self._condition_func

        @condition_func.setter
        def condition_func(self, condition_func):
            if condition_func is not None:
                self._condition_func = condition_func
            else:
                # Default: leave the reference parameters untouched.
                self._condition_func = lambda reference_params: reference_params
            self._required_variables = infer_parameters_from_function(self.condition_func)

        @property
        def required_variables(self):
            """ The required variables to pass into the condition function. """
            return self._required_variables

        def get_instantiation_dict(self):
            """
            Return the instantiation dictionary of the superclass with the
            entries of `reference_params` written over it.
            """
            instantiation_dict = super().get_instantiation_dict()
            instantiation_dict.update(self.reference_params)
            return instantiation_dict

        def reset_to_reference_parameters(self):
            """
            Set every attribute named in `reference_params` back to its
            reference value.
            """
            for attribute, reference_value in self.reference_params.items():
                setattr(self, attribute, reference_value)

        def __repr__(self):
            """Overrides the special method __repr__.

            Builds `ClassName(key=value, ...)` from the instantiation dictionary.
            The condition function is represented by the string
            `<module>.<name>`.

            Returns
            =======
            str: A string representation of this instance

            """
            instantiation_dict = self.get_instantiation_dict()
            func = instantiation_dict["condition_func"]
            instantiation_dict["condition_func"] = ".".join((func.__module__, func.__name__))
            args = ', '.join(
                '{}={}'.format(key, repr(value))
                for key, value in instantiation_dict.items()
            )
            return "{}({})".format(self.__class__.__name__, args)

    return ConditionalPrior

239 

240 

class ConditionalBasePrior(conditional_prior_factory(Prior)):
    """Conditional version of `Prior`, created by `conditional_prior_factory`."""

243 

244 

class ConditionalUniform(conditional_prior_factory(Uniform)):
    """Conditional version of `Uniform`, created by `conditional_prior_factory`."""

247 

248 

class ConditionalDeltaFunction(conditional_prior_factory(DeltaFunction)):
    """Conditional version of `DeltaFunction`, created by `conditional_prior_factory`."""

251 

252 

class ConditionalPowerLaw(conditional_prior_factory(PowerLaw)):
    """Conditional version of `PowerLaw`, created by `conditional_prior_factory`."""

255 

256 

class ConditionalGaussian(conditional_prior_factory(Gaussian)):
    """Conditional version of `Gaussian`, created by `conditional_prior_factory`."""

259 

260 

class ConditionalLogUniform(conditional_prior_factory(LogUniform)):
    """Conditional version of `LogUniform`, created by `conditional_prior_factory`."""

263 

264 

class ConditionalSymmetricLogUniform(conditional_prior_factory(SymmetricLogUniform)):
    """Conditional version of `SymmetricLogUniform`, created by `conditional_prior_factory`."""

267 

268 

class ConditionalCosine(conditional_prior_factory(Cosine)):
    """Conditional version of `Cosine`, created by `conditional_prior_factory`."""

271 

272 

class ConditionalSine(conditional_prior_factory(Sine)):
    """Conditional version of `Sine`, created by `conditional_prior_factory`."""

275 

276 

class ConditionalTruncatedGaussian(conditional_prior_factory(TruncatedGaussian)):
    """Conditional version of `TruncatedGaussian`, created by `conditional_prior_factory`."""

279 

280 

class ConditionalHalfGaussian(conditional_prior_factory(HalfGaussian)):
    """Conditional version of `HalfGaussian`, created by `conditional_prior_factory`."""

283 

284 

class ConditionalLogNormal(conditional_prior_factory(LogNormal)):
    """Conditional version of `LogNormal`, created by `conditional_prior_factory`."""

287 

288 

class ConditionalExponential(conditional_prior_factory(Exponential)):
    """Conditional version of `Exponential`, created by `conditional_prior_factory`."""

291 

292 

class ConditionalStudentT(conditional_prior_factory(StudentT)):
    """Conditional version of `StudentT`, created by `conditional_prior_factory`."""

295 

296 

class ConditionalBeta(conditional_prior_factory(Beta)):
    """Conditional version of `Beta`, created by `conditional_prior_factory`."""

299 

300 

class ConditionalLogistic(conditional_prior_factory(Logistic)):
    """Conditional version of `Logistic`, created by `conditional_prior_factory`."""

303 

304 

class ConditionalCauchy(conditional_prior_factory(Cauchy)):
    """Conditional version of `Cauchy`, created by `conditional_prior_factory`."""

307 

308 

class ConditionalGamma(conditional_prior_factory(Gamma)):
    """Conditional version of `Gamma`, created by `conditional_prior_factory`."""

311 

312 

class ConditionalChiSquared(conditional_prior_factory(ChiSquared)):
    """Conditional version of `ChiSquared`, created by `conditional_prior_factory`."""

315 

316 

class ConditionalFermiDirac(conditional_prior_factory(FermiDirac)):
    """Conditional version of `FermiDirac`, created by `conditional_prior_factory`."""

319 

320 

class ConditionalInterped(conditional_prior_factory(Interped)):
    """Conditional version of `Interped`, created by `conditional_prior_factory`."""

323 

324 

class DirichletElement(ConditionalBeta):
    r"""
    Single element in a dirichlet distribution

    This is a conditional Beta prior with `alpha=1` and
    `beta=n_dimensions - order - 1`. Its minimum stays at the reference
    value and its maximum is set to one minus the sum of all elements of
    lower order.

    The probability scales as

    .. math::
        p(x_n) \propto (x_\max - x_n)^{(N - n - 2)}

    for :math:`x_n < x_\max`, where :math:`x_\max` is one minus the sum of
    :math:`x_i` for :math:`i < n`

    Examples
    ========
    n_dimensions = 1:

    .. math::
        p(x_0) \propto 1 ; 0 < x_0 < 1

    n_dimensions = 2:

    .. math::
        p(x_0) &\propto (1 - x_0) ; 0 < x_0 < 1
        p(x_1) &\propto 1 ; 0 < x_1 < 1

    Parameters
    ==========
    order: int
        Order of this element of the dirichlet distribution.
    n_dimensions: int
        Total number of elements of the dirichlet distribution
    label: str
        Label for the dirichlet distribution.
        This should be the same for all elements.

    """
    # NOTE(review): the examples above do not match the exponent
    # N - n - 2 given in the formula for N = n_dimensions; they look like
    # they count only the free elements. Confirm the intended convention.

    def __init__(self, order, n_dimensions, label):
        """ Set up the Beta prior and the names of the lower order elements. """
        super().__init__(
            condition_func=self.dirichlet_condition,
            name=label + str(order),
            minimum=0, maximum=1, alpha=1, beta=n_dimensions - order - 1,
        )
        self.label = label
        self.n_dimensions = n_dimensions
        self.order = order
        # This prior depends on every element of lower order; this replaces
        # the value set when `condition_func` was assigned.
        self._required_variables = [
            label + str(lower) for lower in range(order)
        ]
        # The parent constructor renamed the class, set the name back.
        self.__class__.__name__ = 'DirichletElement'
        self.__class__.__qualname__ = 'DirichletElement'

    def dirichlet_condition(self, reference_parms, **kwargs):
        """
        Condition function of this element.

        Parameters
        ==========
        reference_parms: dict
            Reference parameters, only `minimum` is read.
        kwargs:
            Values of the elements `label + str(ii)` for `ii < order`.

        Returns
        =======
        dict: `minimum` from the reference parameters and `maximum` equal to
            one minus the sum of the lower order elements
        """
        used_up = sum(
            kwargs[self.label + str(lower)] for lower in range(self.order)
        )
        remaining = 1 - used_up
        return {"minimum": reference_parms["minimum"], "maximum": remaining}

    def __repr__(self):
        """ Use the representation of `Prior` directly. """
        return Prior.__repr__(self)

    def get_instantiation_dict(self):
        """ Use the instantiation dictionary of `Prior` directly. """
        return Prior.get_instantiation_dict(self)

388 

389 

class ConditionalPriorException(PriorException):
    """ Common base class of all exceptions specific to conditional priors """

392 

393 

class IllegalRequiredVariablesException(ConditionalPriorException):
    """ Raised for problems with the required variables of a conditional prior. """