Coverage for bilby/core/prior/base.py: 89%

214 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1from importlib import import_module 

2import json 

3import os 

4import re 

5 

6import numpy as np 

7import scipy.stats 

8from scipy.interpolate import interp1d 

9 

10from ..utils import ( 

11 infer_args_from_method, 

12 BilbyJsonEncoder, 

13 decode_bilby_json, 

14 logger, 

15 get_dict_with_properties, 

16) 

17 

18 

class Prior(object):
    """Base class for one-dimensional priors.

    Subclasses are expected to override :meth:`rescale` and :meth:`prob`
    (and may override :meth:`ln_prob` and :meth:`cdf`); the implementations
    here are generic fallbacks.
    """

    # Mapping from parameter name to a default LaTeX label, consulted when no
    # explicit label is given. Only read in this class, never mutated.
    _default_latex_labels = {}

    def __init__(self, name=None, latex_label=None, unit=None, minimum=-np.inf,
                 maximum=np.inf, check_range_nonzero=True, boundary=None):
        """ Implements a Prior object

        Parameters
        ==========
        name: str, optional
            Name associated with prior.
        latex_label: str, optional
            Latex label associated with prior, used for plotting. Defaults to
            the entry for ``name`` in ``_default_latex_labels``, else ``name``.
        unit: str, optional
            If given, a Latex string describing the units of the parameter.
        minimum: float, optional
            Minimum of the domain, default=-np.inf
        maximum: float, optional
            Maximum of the domain, default=np.inf
        check_range_nonzero: boolean, optional
            If True, checks that maximum is strictly greater than minimum.
        boundary: str, optional
            The boundary condition of the prior, can be 'periodic',
            'reflective' or None.
            Currently implemented in cpnest, dynesty and pymultinest.

        Raises
        ======
        ValueError:
            If ``check_range_nonzero`` is True and ``maximum <= minimum``, or
            if ``boundary`` is not one of the accepted values.
        """
        if check_range_nonzero and maximum <= minimum:
            raise ValueError(
                "maximum {} <= minimum {} for {} prior on {}".format(
                    maximum, minimum, type(self).__name__, name
                )
            )
        # ``name`` must be set before ``latex_label``: the label setter falls
        # back to a default derived from the name.
        self.name = name
        self.latex_label = latex_label
        self.unit = unit
        self.minimum = minimum
        self.maximum = maximum
        self.check_range_nonzero = check_range_nonzero
        self.least_recently_sampled = None
        self.boundary = boundary
        self._is_fixed = False

    def __call__(self):
        """Overrides the __call__ special method. Calls the sample method.

        Returns
        =======
        float: The return value of the sample method.
        """
        return self.sample()

    def __eq__(self, other):
        """
        Test equality of two prior objects.

        Returns true iff:

        - The class of the two priors are the same
        - Both priors have the same keys in the __dict__ attribute
        - The instantiation arguments match

        We don't check that all entries in the __dict__ attribute
        are equal as some attributes are variable for conditional
        priors.

        Parameters
        ==========
        other: Prior
            The prior to compare with

        Returns
        =======
        bool
            Whether the priors are equivalent

        Notes
        =====
        Instantiation arguments that are scipy frozen continuous
        distributions (the type of :code:`scipy.stats.beta(1., 1.)`) are
        skipped in the comparison. It may be possible to remove this as we
        now only check instantiation arguments.

        """
        if self.__class__ != other.__class__:
            return False
        if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()):
            return False
        this_dict = self.get_instantiation_dict()
        other_dict = other.get_instantiation_dict()
        # Build the frozen-distribution type once instead of once per key.
        frozen_distribution_type = type(scipy.stats.beta(1., 1.))
        for key in this_dict:
            if key == "least_recently_sampled":
                continue
            if isinstance(this_dict[key], np.ndarray):
                if not np.array_equal(this_dict[key], other_dict[key]):
                    return False
            elif isinstance(this_dict[key], frozen_distribution_type):
                continue
            elif not this_dict[key] == other_dict[key]:
                return False
        return True

    def sample(self, size=None):
        """Draw a sample from the prior

        The draw is stored on ``least_recently_sampled`` as well as returned.

        Parameters
        ==========
        size: int or tuple of ints, optional
            See numpy.random.uniform docs

        Returns
        =======
        float: A random number between 0 and 1, rescaled to match the distribution of this Prior

        """
        from ..utils.random import rng

        self.least_recently_sampled = self.rescale(rng.uniform(0, 1, size))
        return self.least_recently_sampled

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the prior.

        This should be overwritten by each subclass.

        Parameters
        ==========
        val: Union[float, int, array_like]
            A random number between 0 and 1

        Returns
        =======
        None

        """
        return None

    def prob(self, val):
        """Return the prior probability of val, this should be overwritten

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        np.nan

        """
        return np.nan

    def cdf(self, val):
        """Generic method to calculate the CDF; can be overwritten in subclass.

        The PDF from :meth:`prob` is evaluated on 1000 evenly spaced points
        between ``minimum`` and ``maximum``, integrated with the cumulative
        trapezoid rule and linearly interpolated at ``val``. Values below the
        range map to 0 and values above it to 1. The result is not
        renormalised.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[float, array_like]: The interpolated CDF at ``val``.

        Raises
        ======
        ValueError:
            If either bound of the prior is infinite.
        """
        from scipy.integrate import cumulative_trapezoid
        if np.any(np.isinf([self.minimum, self.maximum])):
            raise ValueError(
                "Unable to use the generic CDF calculation for priors with "
                "infinite support")
        x = np.linspace(self.minimum, self.maximum, 1000)
        pdf = self.prob(x)
        cdf = cumulative_trapezoid(pdf, x, initial=0)
        interp = interp1d(x, cdf, assume_sorted=True, bounds_error=False,
                          fill_value=(0, 1))
        return interp(val)

    def ln_prob(self, val):
        """Return the natural log of the prior probability of val.

        Subclasses may override this with a direct implementation.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[float, array_like]: ``np.log(self.prob(val))``; ``-inf`` where
        the probability is zero (the divide-by-zero warning is suppressed).

        """
        with np.errstate(divide='ignore'):
            return np.log(self.prob(val))

    def is_in_prior_range(self, val):
        """Check whether val lies inside the closed prior range.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[bool, array_like]: True where ``minimum <= val <= maximum``,
        False elsewhere (elementwise for arrays).

        """
        return (val >= self.minimum) & (val <= self.maximum)

    def __repr__(self):
        """Overrides the special method __repr__.

        Returns a representation of this instance that resembles how it is instantiated.
        Works correctly for all child classes

        Returns
        =======
        str: A string representation of this instance

        """
        prior_name = self.__class__.__name__
        prior_module = self.__class__.__module__
        instantiation_dict = self.get_instantiation_dict()
        args = ', '.join([f'{key}={repr(instantiation_dict[key])}' for key in instantiation_dict])
        if "bilby.core.prior" in prior_module:
            return f"{prior_name}({args})"
        else:
            return f"{prior_module}.{prior_name}({args})"

    @property
    def is_fixed(self):
        """
        Returns True if the prior is fixed and should not be used in the
        sampler.

        This reports the ``_is_fixed`` flag, which is False by default and is
        set to True by subclasses that represent fixed quantities.

        Returns
        =======
        bool: Whether it's fixed or not!

        """
        return self._is_fixed

    @property
    def latex_label(self):
        """Latex label that can be used for plots.

        Draws from a set of default labels if no label is given

        Returns
        =======
        str: A latex representation for this prior

        """
        return self.__latex_label

    @latex_label.setter
    def latex_label(self, latex_label=None):
        if latex_label is None:
            self.__latex_label = self.__default_latex_label
        else:
            self.__latex_label = latex_label

    @property
    def unit(self):
        return self.__unit

    @unit.setter
    def unit(self, unit):
        self.__unit = unit

    @property
    def latex_label_with_unit(self):
        """ If a unit is specified, returns a string of the latex label and unit """
        if self.unit is not None:
            return "{} [{}]".format(self.latex_label, self.unit)
        else:
            return self.latex_label

    @property
    def minimum(self):
        return self._minimum

    @minimum.setter
    def minimum(self, minimum):
        self._minimum = minimum

    @property
    def maximum(self):
        return self._maximum

    @maximum.setter
    def maximum(self, maximum):
        self._maximum = maximum

    @property
    def width(self):
        """Length of the prior range, ``maximum - minimum``."""
        return self.maximum - self.minimum

    def get_instantiation_dict(self):
        """Return ``{argument name: current value}`` for each ``__init__`` argument."""
        subclass_args = infer_args_from_method(self.__init__)
        dict_with_properties = get_dict_with_properties(self)
        return {key: dict_with_properties[key] for key in subclass_args}

    @property
    def boundary(self):
        return self._boundary

    @boundary.setter
    def boundary(self, boundary):
        if boundary not in ['periodic', 'reflective', None]:
            raise ValueError('{} is not a valid setting for prior boundaries'.format(boundary))
        self._boundary = boundary

    @property
    def __default_latex_label(self):
        if self.name in self._default_latex_labels.keys():
            label = self._default_latex_labels[self.name]
        else:
            label = self.name
        return label

    def to_json(self):
        return json.dumps(self, cls=BilbyJsonEncoder)

    @classmethod
    def from_json(cls, dct):
        return decode_bilby_json(dct)

    @classmethod
    def from_repr(cls, string):
        """Generate the prior from its __repr__"""
        return cls._from_repr(string)

    @classmethod
    def _from_repr(cls, string):
        """Build an instance of ``cls`` from the argument list of its repr.

        ``string`` is the text between the outer parentheses, e.g.
        ``"minimum=0, maximum=1"``.

        Raises
        ======
        AttributeError:
            If a keyword is not an ``__init__`` argument of ``cls`` (unless
            ``cls`` defines ``reference_params``).
        """
        subclass_args = infer_args_from_method(cls.__init__)

        # Remove formatting whitespace, but keep whitespace that is part of a
        # quoted value (e.g. a LaTeX label).
        string = cls._strip_unquoted_whitespace(string)
        kwargs = cls._split_repr(string)
        for key in kwargs:
            val = kwargs[key]
            if key not in subclass_args and not hasattr(cls, "reference_params"):
                raise AttributeError('Unknown argument {} for class {}'.format(
                    key, cls.__name__))
            kwargs[key] = cls._parse_argument_string(val)
            if key in ["condition_func", "conversion_function"] and isinstance(kwargs[key], str):
                # Function given by name: resolve "module.func" by import, or
                # a bare name relative to this module.
                if "." in kwargs[key]:
                    module = '.'.join(kwargs[key].split('.')[:-1])
                    name = kwargs[key].split('.')[-1]
                else:
                    module = __name__
                    name = kwargs[key]
                kwargs[key] = getattr(import_module(module), name)
        return cls(**kwargs)

    @classmethod
    def _split_repr(cls, string):
        """Split a repr argument list into ``{name: unparsed value string}``.

        Commas nested inside (), [] or {} or inside quoted strings do not
        separate arguments. Empty arguments (from an empty argument list or a
        trailing comma) are ignored. An argument without a leading
        ``identifier=`` is treated as positional and named after the matching
        ``__init__`` argument of ``cls``.
        """
        args = [arg.strip() for arg in cls._split_top_level(string)]
        args = [arg for arg in args if arg]
        # Only introspect the signature if a positional argument needs it.
        subclass_args = None
        kwargs = dict()
        for ii, arg in enumerate(args):
            key, separator, val = arg.partition('=')
            key = key.strip()
            if separator and key.isidentifier():
                kwargs[key] = val.strip()
                continue
            logger.debug(
                'Reading priors with non-keyword arguments is dangerous!')
            if subclass_args is None:
                subclass_args = infer_args_from_method(cls.__init__)
            kwargs[subclass_args[ii]] = arg
        return kwargs

    @staticmethod
    def _iter_quote_state(string):
        """Yield ``(char, quoted)`` for each character of ``string``.

        ``quoted`` is True for characters inside a single- or double-quoted
        literal, including the delimiting quotes. A backslash inside a
        literal escapes the following character.
        """
        quote = None
        escaped = False
        for char in string:
            if quote is None:
                if char in "'\"":
                    quote = char
                yield char, quote is not None
                continue
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                quote = None
            yield char, True

    @classmethod
    def _strip_unquoted_whitespace(cls, string):
        """Remove whitespace from ``string`` except inside quoted literals."""
        return ''.join(
            char for char, quoted in cls._iter_quote_state(string)
            if quoted or not char.isspace()
        )

    @classmethod
    def _split_top_level(cls, string):
        """Split ``string`` on commas not nested in brackets or quotes."""
        pieces = []
        current = []
        depth = 0
        for char, quoted in cls._iter_quote_state(string):
            if not quoted:
                if char in "([{":
                    depth += 1
                elif char in ")]}":
                    depth = max(depth - 1, 0)
                elif char == "," and depth == 0:
                    pieces.append(''.join(current))
                    current = []
                    continue
            current.append(char)
        pieces.append(''.join(current))
        return pieces

    @classmethod
    def _parse_argument_string(cls, val):
        """
        Parse a string into the appropriate type for prior reading.

        The following checks are applied in order:

        - If the string is 'None':
          `None` is returned.
        - Else if the string is a raw or unicode literal, e.g., r'foo':
          The text between the quotes is returned, e.g., foo.
        - Else if the string starts and ends with ', e.g., 'foo':
          A stripped version of the string is returned, e.g., foo.
        - Else if the string contains an open parenthesis, (, and does not
          start with [ or {:
          The string is interpreted as a call to instantiate another prior
          class, Bilby will attempt to recursively construct that prior,
          e.g., Uniform(minimum=0, maximum=1), my.custom.PriorClass(**kwargs).
        - Else:
          Try to evaluate the string using `eval` with ``np``, ``inf`` and
          ``pi`` in scope (Python built-ins are also available), e.g.,
          np.pi / 2, 1.57. If this raises a NameError and the string
          contains a ".", it is treated as a path to a Python object and
          imported, e.g., "some_module.some_function" returns
          :code:`import some_module; return some_module.some_function`.
          If it raises a NameError and contains no ".", the string is
          returned unchanged.

        Parameters
        ==========
        val: str
            The string version of the argument

        Returns
        =======
        val: object
            The parsed version of the argument.

        Raises
        ======
        TypeError:
            If val is a dotted path whose final attribute cannot be found in
            the imported module.
        """
        if val == 'None':
            val = None
        elif re.sub(r'\'.*\'', '', val) in ['r', 'u']:
            val = val[2:-1]
        elif val.startswith("'") and val.endswith("'"):
            val = val.strip("'")
        elif '(' in val and not val.startswith(("[", "{")):
            other_cls = val.split('(')[0]
            vals = '('.join(val.split('(')[1:])[:-1]
            if "." in other_cls:
                module = '.'.join(other_cls.split('.')[:-1])
                other_cls = other_cls.split('.')[-1]
            else:
                # Default to the package containing this module.
                module = __name__.replace('.' + os.path.basename(__file__).replace('.py', ''), '')
            other_cls = getattr(import_module(module), other_cls)
            val = other_cls.from_repr(vals)
        else:
            try:
                # SECURITY: eval is not a sandbox (built-ins remain reachable);
                # only parse prior strings from trusted sources.
                val = eval(val, dict(), dict(np=np, inf=np.inf, pi=np.pi))
            except NameError:
                if "." in val:
                    module_name, _, attribute = val.rpartition('.')
                    module = import_module(module_name)
                    if not hasattr(module, attribute):
                        raise TypeError(
                            "Cannot evaluate prior, "
                            f"failed to parse argument {val}"
                        )
                    val = getattr(module, attribute)
        return val

462 

463 

class Constraint(Prior):
    """Prior-like object that tests whether a value lies strictly inside
    ``(minimum, maximum)``.

    Instances are flagged as fixed, so they are not used directly by the
    sampler.
    """

    def __init__(self, minimum, maximum, name=None, latex_label=None,
                 unit=None):
        super().__init__(
            name=name,
            latex_label=latex_label,
            unit=unit,
            minimum=minimum,
            maximum=maximum,
        )
        self._is_fixed = True

    def prob(self, val):
        """Return True (elementwise for arrays) where ``minimum < val < maximum``."""
        above_minimum = val > self.minimum
        below_maximum = val < self.maximum
        return above_minimum & below_maximum

474 

475 

class PriorException(Exception):
    """Common base type for exceptions raised by prior classes."""