Coverage for bilby/core/prior/dict.py: 89%

461 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import json 

2import os 

3import re 

4from importlib import import_module 

5from io import open as ioopen 

6 

7import numpy as np 

8 

9from .analytical import DeltaFunction 

10from .base import Prior, Constraint 

11from .joint import JointPrior 

12from ..utils import ( 

13 logger, 

14 check_directory_exists_and_if_not_mkdir, 

15 BilbyJsonEncoder, 

16 decode_bilby_json, 

17) 

18 

19 

class PriorDict(dict):
    def __init__(self, dictionary=None, filename=None, conversion_function=None):
        """A dictionary of priors

        Parameters
        ==========
        dictionary: Union[dict, str, None]
            If given, a dictionary to generate the prior set. A string is
            interpreted as the name of a prior file.
        filename: Union[str, None]
            If given (and ``dictionary`` is not), a file containing the prior
            to generate the prior set.
        conversion_function: func
            Function to convert between sampled parameters and constraints.
            Default is no conversion.

        Raises
        ======
        ValueError
            If ``dictionary`` is not None, a dict, or a str (and no
            ``filename`` string is given).
        """
        super(PriorDict, self).__init__()
        if isinstance(dictionary, dict):
            self.from_dictionary(dictionary)
        elif type(dictionary) is str:
            logger.debug(
                'Argument "dictionary" is a string.'
                + " Assuming it is intended as a file name."
            )
            self.from_file(dictionary)
        elif type(filename) is str:
            self.from_file(filename)
        elif dictionary is not None:
            raise ValueError("PriorDict input dictionary not understood")
        # Constraint normalization factors, keyed by tuples of parameter keys.
        self._cached_normalizations = {}

        self.convert_floats_to_delta_functions()

        if conversion_function is not None:
            self.conversion_function = conversion_function
        else:
            self.conversion_function = self.default_conversion_function

    def evaluate_constraints(self, sample):
        """Return the product of all applicable constraint probabilities.

        ``sample`` is first passed through ``self.conversion_function``; each
        :class:`Constraint` prior whose key is present in the converted sample
        contributes its ``prob`` to the product.

        Parameters
        ==========
        sample: dict
            Parameter values (scalars or arrays).

        Returns
        =======
        The product of the constraint probabilities; ``1`` if no constraint
        applies.
        """
        out_sample = self.conversion_function(sample)
        prob = 1
        for key in self:
            if isinstance(self[key], Constraint) and key in out_sample:
                prob *= self[key].prob(out_sample[key])
        return prob

    def default_conversion_function(self, sample):
        """
        Placeholder parameter conversion function.

        Parameters
        ==========
        sample: dict
            Dictionary to convert

        Returns
        =======
        sample: dict
            Same as input
        """
        return sample

    def to_file(self, outdir, label):
        """Write the prior distribution to ``<outdir>/<label>.prior``.

        Parameters
        ==========
        outdir: str
            output directory name (created if missing)
        label: str
            Output file naming scheme
        """

        check_directory_exists_and_if_not_mkdir(outdir)
        prior_file = os.path.join(outdir, "{}.prior".format(label))
        logger.debug("Writing priors to {}".format(prior_file))
        joint_dists = []
        with open(prior_file, "w") as outfile:
            for key in self.keys():
                if isinstance(self[key], JointPrior):
                    # Joint priors share one distribution object: write the
                    # distribution once under a synthetic name and refer to it
                    # by that name from each member prior.
                    distname = "_".join(self[key].dist.names) + "_{}".format(
                        self[key].dist.distname
                    )
                    if distname not in joint_dists:
                        joint_dists.append(distname)
                        outfile.write("{} = {}\n".format(distname, self[key].dist))
                    diststr = repr(self[key].dist)
                    priorstr = repr(self[key])
                    outfile.write(
                        "{} = {}\n".format(key, priorstr.replace(diststr, distname))
                    )
                else:
                    outfile.write("{} = {}\n".format(key, self[key]))

    def _get_json_dict(self):
        """Return a JSON-serialisable dict describing this prior dict."""
        self.convert_floats_to_delta_functions()
        total_dict = {key: json.loads(self[key].to_json()) for key in self}
        # Metadata used by the decoder to rebuild the right class.
        total_dict["__prior_dict__"] = True
        total_dict["__module__"] = self.__module__
        total_dict["__name__"] = self.__class__.__name__
        return total_dict

    def to_json(self, outdir, label):
        """Write the prior dict as JSON to ``<outdir>/<label>_prior.json``.

        Parameters
        ==========
        outdir: str
            Output directory name (created if missing)
        label: str
            Output file naming scheme
        """
        check_directory_exists_and_if_not_mkdir(outdir)
        prior_file = os.path.join(outdir, "{}_prior.json".format(label))
        logger.debug("Writing priors to {}".format(prior_file))
        with open(prior_file, "w") as outfile:
            json.dump(self._get_json_dict(), outfile, cls=BilbyJsonEncoder, indent=2)

    def from_file(self, filename):
        """Reads in a prior from a file specification

        Parameters
        ==========
        filename: str
            Name of the file to be read in

        Notes
        =====
        Lines whose first non-whitespace character is '#', and lines that are
        empty or contain only whitespace, are ignored.
        Priors can be loaded from:

        - bilby.core.prior as, e.g., :code:`foo = Uniform(minimum=0, maximum=1)`
        - floats, e.g., :code:`foo = 1`
        - bilby.gw.prior as, e.g., :code:`foo = bilby.gw.prior.AlignedSpin()`
        - other external modules, e.g., :code:`foo = my.module.CustomPrior(...)`

        """

        prior = dict()
        with ioopen(filename, "r", encoding="unicode_escape") as f:
            for line in f:
                stripped = line.strip()
                # Skip comments and blank/whitespace-only lines; previously the
                # latter produced bogus entries such as "" or "\t".
                if not stripped or stripped.startswith("#"):
                    continue
                elements = line.split("=")
                # Parameter names never contain whitespace, so drop all of it.
                key = "".join(elements[0].split())
                val = "=".join(elements[1:]).strip()
                prior[key] = val
        self.from_dictionary(prior)

    @classmethod
    def _get_from_json_dict(cls, prior_dict):
        """Rebuild a prior dict from the output of :meth:`_get_json_dict`.

        The class named by ``__module__``/``__name__`` is used when importable,
        otherwise ``cls``. The metadata keys are removed from ``prior_dict``.
        """
        try:
            class_ = getattr(
                import_module(prior_dict["__module__"]), prior_dict["__name__"]
            )
        except ImportError:
            logger.debug(
                "Cannot import prior module {}.{}".format(
                    prior_dict["__module__"], prior_dict["__name__"]
                )
            )
            class_ = cls
        except KeyError:
            logger.debug("Cannot find module name to load")
            class_ = cls
        for key in ["__module__", "__name__", "__prior_dict__"]:
            if key in prior_dict:
                del prior_dict[key]
        obj = class_(prior_dict)
        return obj

    @classmethod
    def from_json(cls, filename):
        """Reads in a prior from a json file

        Parameters
        ==========
        filename: str
            Name of the file to be read in
        """
        with open(filename, "r") as ff:
            obj = json.load(ff, object_hook=decode_bilby_json)

        # make sure priors containing JointDists are properly handled and point
        # to the same object when required
        jointdists = {}
        for key in obj:
            if isinstance(obj[key], JointPrior):
                for name in obj[key].dist.names:
                    jointdists[name] = obj[key].dist
        # set dist for joint values so that they point to the same object
        for key in obj:
            if isinstance(obj[key], JointPrior):
                obj[key].dist = jointdists[key]

        return obj

    def from_dictionary(self, dictionary):
        """Populate this prior dict from a mapping of prior specifications.

        Supported values: :class:`Prior` instances (kept as-is); ints/floats
        and numeric strings (converted to :class:`DeltaFunction`); strings of
        the form ``Class(args)`` or ``module.Class(args)``; and dicts with
        ``__module__``/``__name__``/``kwargs`` entries. The caller's mapping
        is not modified; a shallow copy is parsed.

        Parameters
        ==========
        dictionary: dict
            Mapping from parameter name to prior specification.

        Raises
        ======
        TypeError
            If an entry cannot be parsed.
        ValueError
            If a ``MultivariateGaussian`` entry lacks a ``dist`` argument or
            references a distribution that has not been defined earlier.
        ModuleNotFoundError
            If the module named in a string entry cannot be imported.
        """
        # Parse a copy so the caller's dict is not rewritten (floats replaced,
        # multivariate-distribution entries popped).
        dictionary = dict(dictionary)
        mvgkwargs = {}
        for key in list(dictionary.keys()):
            val = dictionary[key]
            if isinstance(val, Prior):
                continue
            elif isinstance(val, (int, float)):
                dictionary[key] = DeltaFunction(peak=val)
            elif isinstance(val, str):
                cls = val.split("(")[0]
                args = "(".join(val.split("(")[1:])[:-1]
                try:
                    dictionary[key] = DeltaFunction(peak=float(cls))
                    logger.debug("{} converted to DeltaFunction prior".format(key))
                    continue
                except ValueError:
                    pass
                if "." in cls:
                    module = ".".join(cls.split(".")[:-1])
                    cls = cls.split(".")[-1]
                else:
                    # Unqualified names are looked up in the enclosing package.
                    module = __name__.replace(
                        "." + os.path.basename(__file__).replace(".py", ""), ""
                    )
                try:
                    cls = getattr(import_module(module), cls, cls)
                except ModuleNotFoundError:
                    logger.error(
                        "Cannot import prior class {} for entry: {}={}".format(
                            cls, key, val
                        )
                    )
                    raise
                if key.lower() in ["conversion_function", "condition_func"]:
                    setattr(self, key, cls)
                elif isinstance(cls, str):
                    # getattr fell back to the name itself: class not found.
                    if "(" in val:
                        raise TypeError("Unable to parse prior class {}".format(cls))
                    else:
                        continue
                elif cls.__name__ in [
                    "MultivariateGaussianDist",
                    "MultivariateNormalDist",
                ]:
                    # Distributions are not priors themselves; keep them aside
                    # for the MultivariateGaussian entries that reference them.
                    dictionary.pop(key)
                    if key not in mvgkwargs:
                        mvgkwargs[key] = cls.from_repr(args)
                elif cls.__name__ in ["MultivariateGaussian", "MultivariateNormal"]:
                    mgkwargs = {
                        item[0].strip(): cls._parse_argument_string(item[1])
                        for item in cls._split_repr(
                            ", ".join(
                                [arg for arg in args.split(",") if "dist=" not in arg]
                            )
                        ).items()
                    }
                    keymatch = re.match(r"dist=(?P<distkey>\S+),", args)
                    if keymatch is None:
                        raise ValueError(
                            "'dist' argument for MultivariateGaussian is not specified"
                        )

                    if keymatch["distkey"] not in mvgkwargs:
                        raise ValueError(
                            f"MultivariateGaussianDist {keymatch['distkey']} must be defined before {cls.__name__}"
                        )

                    mgkwargs["dist"] = mvgkwargs[keymatch["distkey"]]
                    dictionary[key] = cls(**mgkwargs)
                else:
                    try:
                        dictionary[key] = cls.from_repr(args)
                    except TypeError as e:
                        raise TypeError(
                            "Unable to parse prior, bad entry: {} "
                            "= {}. Error message {}".format(key, val, e)
                        ) from e
            elif isinstance(val, dict):
                try:
                    _class = getattr(
                        import_module(val.get("__module__", "none")),
                        val.get("__name__", "none"),
                    )
                    dictionary[key] = _class(**val.get("kwargs", dict()))
                except ImportError:
                    logger.debug(
                        "Cannot import prior module {}.{}".format(
                            val.get("__module__", "none"), val.get("__name__", "none")
                        )
                    )
                    logger.warning(
                        "Cannot convert {} into a prior object. "
                        "Leaving as dictionary.".format(key)
                    )
                    continue
            else:
                raise TypeError(
                    "Unable to parse prior, bad entry: {} "
                    "= {} of type {}".format(key, val, type(val))
                )
        self.update(dictionary)

    def convert_floats_to_delta_functions(self):
        """Convert all float parameters to delta functions"""
        for key in self:
            if isinstance(self[key], Prior):
                continue
            elif isinstance(self[key], (float, int)):
                self[key] = DeltaFunction(self[key])
                logger.debug("{} converted to delta function prior.".format(key))
            else:
                logger.debug(
                    "{} cannot be converted to delta function prior.".format(key)
                )

    def fill_priors(self, likelihood, default_priors_file=None):
        """
        Fill dictionary of priors based on required parameters of likelihood

        Any floats in prior will be converted to delta function prior. Any
        required, non-specified parameters will use the default.

        Note: if `likelihood` has `non_standard_sampling_parameter_keys`, then
        this will set-up default priors for those as well.

        Parameters
        ==========
        likelihood: bilby.likelihood.GravitationalWaveTransient instance
            Used to infer the set of parameters to fill the prior with
        default_priors_file: str, optional
            If given, a file containing the default priors.


        Returns
        =======
        prior: dict
            The filled prior dictionary

        """

        self.convert_floats_to_delta_functions()

        missing_keys = set(likelihood.parameters) - set(self.keys())

        for missing_key in missing_keys:
            if not self.test_redundancy(missing_key):
                default_prior = create_default_prior(missing_key, default_priors_file)
                if default_prior is None:
                    set_val = likelihood.parameters[missing_key]
                    logger.warning(
                        "Parameter {} has no default prior and is set to {}, this"
                        " will not be sampled and may cause an error.".format(
                            missing_key, set_val
                        )
                    )
                else:
                    self[missing_key] = default_prior

        for key in self:
            self.test_redundancy(key)

    def sample(self, size=None):
        """Draw samples from the prior set

        Parameters
        ==========
        size: int or tuple of ints, optional
            See numpy.random.uniform docs

        Returns
        =======
        dict: Dictionary of the samples
        """
        return self.sample_subset_constrained(keys=list(self.keys()), size=size)

    def sample_subset_constrained_as_array(self, keys=iter([]), size=None):
        """Return an array of samples

        Parameters
        ==========
        keys: list
            A list of keys to sample in; constraint keys are skipped
        size: int
            The number of samples to draw

        Returns
        =======
        array: array_like
            An array of shape (number of non-constraint keys, size) of the
            samples (ordered by keys)
        """
        # Materialise once: the keys are iterated twice below.
        keys = list(keys)
        samples_dict = self.sample_subset_constrained(keys=keys, size=size)
        samples_dict = {key: np.atleast_1d(val) for key, val in samples_dict.items()}
        # Constraint priors are never sampled, so they get no row.
        samples_list = [
            samples_dict[key]
            for key in keys
            if not isinstance(self[key], Constraint)
        ]
        return np.array(samples_list)

    def sample_subset(self, keys=iter([]), size=None):
        """Draw samples from the prior set for parameters which are not a DeltaFunction

        Parameters
        ==========
        keys: list
            List of prior keys to draw samples from
        size: int or tuple of ints, optional
            See numpy.random.uniform docs

        Returns
        =======
        dict: Dictionary of the drawn samples
        """
        self.convert_floats_to_delta_functions()
        samples = dict()
        for key in keys:
            if isinstance(self[key], Constraint):
                continue
            elif isinstance(self[key], Prior):
                samples[key] = self[key].sample(size=size)
            else:
                logger.debug("{} not a known prior.".format(key))
        return samples

    @property
    def non_fixed_keys(self):
        """Keys of non-constraint priors that are not fixed."""
        constraint_keys = self.constraint_keys
        return [
            k
            for k, p in self.items()
            if isinstance(p, Prior) and p.is_fixed is False and k not in constraint_keys
        ]

    @property
    def fixed_keys(self):
        """Keys of non-constraint priors that are fixed (e.g. DeltaFunction)."""
        constraint_keys = self.constraint_keys
        return [
            k
            for k, p in self.items()
            # Non-Prior entries (e.g. unparsed strings) have no ``is_fixed``.
            if isinstance(p, Prior) and p.is_fixed and k not in constraint_keys
        ]

    @property
    def constraint_keys(self):
        """Keys whose prior is a :class:`Constraint`."""
        return [k for k, p in self.items() if isinstance(p, Constraint)]

    def sample_subset_constrained(self, keys=iter([]), size=None):
        """Draw samples of ``keys`` that satisfy all constraints.

        Parameters
        ==========
        keys: iterable of str
            Keys to sample. Constraint keys are ignored. The caller's
            container is not modified.
        size: int or tuple of ints, optional
            Number/shape of samples. ``None`` or ``1`` draws a single
            accepted sample by rejection.

        Returns
        =======
        dict: Dictionary of the accepted samples
        """
        # Materialise so an iterator is not exhausted by the first attempt.
        keys = list(keys)
        if size is None or size == 1:
            while True:
                sample = self.sample_subset(keys=keys, size=size)
                if self.evaluate_constraints(sample):
                    return sample
        else:
            needed = np.prod(size)
            # Build a new list rather than deleting from the caller's keys.
            keys = [key for key in keys if not isinstance(self[key], Constraint)]
            if not keys:
                return {}
            all_samples = {key: np.array([]) for key in keys}
            first_key = keys[0]
            while len(all_samples[first_key]) < needed:
                samples = self.sample_subset(keys=keys, size=needed)
                keep = np.array(self.evaluate_constraints(samples), dtype=bool)
                for key in keys:
                    all_samples[key] = np.hstack(
                        [all_samples[key], samples[key][keep].flatten()]
                    )
            return {
                key: np.reshape(all_samples[key][:needed], size) for key in keys
            }

    def normalize_constraint_factor(
        self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10
    ):
        """Estimate and cache the normalization factor due to constraints.

        The factor is the mean of ``nrepeats`` Monte Carlo estimates of
        (number drawn / number accepted). When the estimates scatter, it is
        rounded to the precision suggested by three times their standard
        deviation.

        Parameters
        ==========
        keys: tuple
            Keys to sample; must be hashable (used as the cache key)
        min_accept: int
            Minimum number of accepted samples per estimate
        sampling_chunk: int
            Number of samples drawn per iteration
        nrepeats: int
            Number of independent estimates to average

        Returns
        =======
        float: The normalization factor
        """
        if keys in self._cached_normalizations.keys():
            return self._cached_normalizations[keys]
        else:
            factor_estimates = [
                self._estimate_normalization(keys, min_accept, sampling_chunk)
                for _ in range(nrepeats)
            ]
            factor = np.mean(factor_estimates)
            if np.std(factor_estimates) > 0:
                decimals = int(-np.floor(np.log10(3 * np.std(factor_estimates))))
                factor_rounded = np.round(factor, decimals)
            else:
                factor_rounded = factor
            self._cached_normalizations[keys] = factor_rounded
            return factor_rounded

    def _estimate_normalization(self, keys, min_accept, sampling_chunk):
        """Return one estimate of (samples drawn) / (samples accepted).

        Returns 1 immediately when the constraint evaluation yields a single
        value for a whole chunk (i.e. no constraint depends on the samples).
        """
        samples = self.sample_subset(keys=keys, size=sampling_chunk)
        keep = np.atleast_1d(self.evaluate_constraints(samples))
        if len(keep) == 1:
            self._cached_normalizations[keys] = 1
            return 1
        all_samples = {key: np.array([]) for key in keys}
        while np.count_nonzero(keep) < min_accept:
            samples = self.sample_subset(keys=keys, size=sampling_chunk)
            for key in samples:
                all_samples[key] = np.hstack([all_samples[key], samples[key].flatten()])
            keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
        factor = len(keep) / np.count_nonzero(keep)
        return factor

    def prob(self, sample, **kwargs):
        """

        Parameters
        ==========
        sample: dict
            Dictionary of the samples of which we want to have the probability of
        kwargs:
            The keyword arguments are passed directly to `np.prod`

        Returns
        =======
        float: Joint probability of all individual sample probabilities

        """
        prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs)

        return self.check_prob(sample, prob)

    def check_prob(self, sample, prob):
        """Apply constraints and their normalization to a joint probability.

        Parameters
        ==========
        sample: dict
            The sample ``prob`` was evaluated at
        prob: float or array_like
            Unconstrained joint probability

        Returns
        =======
        float or array_like: ``prob`` times the normalization factor where the
        constraints hold, zero elsewhere
        """
        ratio = self.normalize_constraint_factor(tuple(sample.keys()))
        if np.all(prob == 0.0):
            return prob * ratio
        else:
            if isinstance(prob, float):
                if self.evaluate_constraints(sample):
                    return prob * ratio
                else:
                    return 0.0
            else:
                constrained_prob = np.zeros_like(prob)
                keep = np.array(self.evaluate_constraints(sample), dtype=bool)
                constrained_prob[keep] = prob[keep] * ratio
                return constrained_prob

    def ln_prob(self, sample, axis=None, normalized=True):
        """

        Parameters
        ==========
        sample: dict
            Dictionary of the samples of which to calculate the log probability
        axis: None or int
            Axis along which the summation is performed
        normalized: bool
            When False, disables calculation of constraint normalization factor
            during prior probability computation. Default value is True.

        Returns
        =======
        float or ndarray:
            Joint log probability of all the individual sample probabilities

        """
        ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis)
        return self.check_ln_prob(sample, ln_prob,
                                  normalized=normalized)

    def check_ln_prob(self, sample, ln_prob, normalized=True):
        """Apply constraints (and optionally their normalization) to a log prob.

        Parameters
        ==========
        sample: dict
            The sample ``ln_prob`` was evaluated at
        ln_prob: float or array_like
            Unconstrained joint log probability
        normalized: bool
            If False, skip the constraint normalization factor

        Returns
        =======
        float or array_like: ``ln_prob`` plus the log normalization where the
        constraints hold, ``-inf`` elsewhere
        """
        if normalized:
            ratio = self.normalize_constraint_factor(tuple(sample.keys()))
        else:
            ratio = 1
        if np.all(np.isinf(ln_prob)):
            return ln_prob
        else:
            if isinstance(ln_prob, float):
                if self.evaluate_constraints(sample):
                    return ln_prob + np.log(ratio)
                else:
                    return -np.inf
            else:
                constrained_ln_prob = -np.inf * np.ones_like(ln_prob)
                keep = np.array(self.evaluate_constraints(sample), dtype=bool)
                constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio)
                return constrained_ln_prob

    def cdf(self, sample):
        """Evaluate the cumulative distribution function at the provided points

        Parameters
        ----------
        sample: dict, pandas.DataFrame
            Dictionary of the samples of which to calculate the CDF

        Returns
        -------
        dict, pandas.DataFrame: Dictionary containing the CDF values

        """
        return sample.__class__(
            {key: self[key].cdf(value) for key, value in sample.items()}
        )

    def rescale(self, keys, theta):
        """Rescale samples from unit cube to prior

        Parameters
        ==========
        keys: list
            List of prior keys to be rescaled
        theta: list
            List of randomly drawn values on a unit cube associated with the prior keys

        Returns
        =======
        list: List of floats containing the rescaled sample
        """
        from matplotlib.cbook import flatten

        return list(
            flatten([self[key].rescale(sample) for key, sample in zip(keys, theta)])
        )

    def test_redundancy(self, key, disable_logging=False):
        """Empty redundancy test, should be overwritten in subclasses"""
        return False

    def test_has_redundant_keys(self):
        """
        Test whether there are redundant keys in self.

        Returns
        =======
        bool: Whether there are redundancies or not
        """
        redundant = False
        for key in self:
            if isinstance(self[key], Constraint):
                continue
            temp = self.copy()
            del temp[key]
            if temp.test_redundancy(key, disable_logging=True):
                logger.warning(
                    "{} is a redundant key in this {}.".format(
                        key, self.__class__.__name__
                    )
                )
                redundant = True
        return redundant

    def copy(self):
        """
        We have to overwrite the copy method as it fails due to the presence of
        defaults.
        """
        return self.__class__(dictionary=dict(self))

652 

653 

class PriorDictException(Exception):
    """Root of the exception hierarchy raised by prior dictionaries."""

656 

657 

class ConditionalPriorDict(PriorDict):
    def __init__(self, dictionary=None, filename=None, conversion_function=None):
        """A prior dict whose priors may depend on other parameters' values.

        Parameters
        ==========
        dictionary: dict
            See parent class
        filename: str
            See parent class
        conversion_function: func
            See parent class
        """
        # These must exist before the parent __init__ runs, because it may
        # assign items and __setitem__ re-resolves the conditions.
        self._conditional_keys = []
        self._unconditional_keys = []
        self._rescale_keys = []
        self._rescale_indexes = []
        self._least_recently_rescaled_keys = []
        super(ConditionalPriorDict, self).__init__(
            dictionary=dictionary,
            filename=filename,
            conversion_function=conversion_function,
        )
        self._resolved = False
        self._resolve_conditions()

    def _resolve_conditions(self):
        """
        Resolves how priors depend on each other and automatically
        sorts them into the right order.
        1. All unconditional priors are put in front in arbitrary order
        2. We loop through all the unsorted conditional priors to find
        which one can go next
        3. We repeat step 2 len(self) number of times to make sure that
        all conditional priors will be sorted in order
        4. We set the `self._resolved` flag to True if all conditional
        priors were added in the right order
        """
        self._unconditional_keys = [
            key for key in self.keys() if not hasattr(self[key], "condition_func")
        ]
        conditional_keys_unsorted = [
            key for key in self.keys() if hasattr(self[key], "condition_func")
        ]
        self._conditional_keys = []
        for _ in range(len(self)):
            for key in conditional_keys_unsorted[:]:
                if self._check_conditions_resolved(key, self.sorted_keys):
                    self._conditional_keys.append(key)
                    conditional_keys_unsorted.remove(key)

        self._resolved = True
        if len(conditional_keys_unsorted) != 0:
            self._resolved = False

    def _check_conditions_resolved(self, key, sampled_keys):
        """Checks if all required variables have already been sampled so we can sample this key"""
        return all(k in sampled_keys for k in self[key].required_variables)

    def sample_subset(self, keys=iter([]), size=None):
        """Draw samples of ``keys`` in dependency order.

        Fixed (DeltaFunction) priors not in ``keys`` are included in the
        working set so conditions on them can be evaluated, but only ``keys``
        are returned. Constraint keys are skipped.

        Parameters
        ==========
        keys: iterable of str
            Keys to sample
        size: int, optional
            Number of samples

        Returns
        =======
        dict: Dictionary of the drawn samples

        Raises
        ======
        IllegalConditionsException
            If the conditions among the selected priors cannot be resolved.
        """
        # Materialise: keys is used for repeated membership tests below, which
        # would silently exhaust an iterator.
        keys = list(keys)
        self.convert_floats_to_delta_functions()
        add_delta_keys = [
            key
            for key in self.keys()
            if key not in keys and isinstance(self[key], DeltaFunction)
        ]
        use_keys = add_delta_keys + keys
        subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys})
        if not subset_dict._resolved:
            raise IllegalConditionsException(
                "The current set of priors contains unresolvable conditions."
            )
        samples = dict()
        for key in subset_dict.sorted_keys:
            if key not in keys or isinstance(self[key], Constraint):
                continue
            if isinstance(self[key], Prior):
                try:
                    samples[key] = subset_dict[key].sample(
                        size=size, **subset_dict.get_required_variables(key)
                    )
                except ValueError:
                    # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw)
                    # If that is the case, we sample each sample individually.
                    required_variables = subset_dict.get_required_variables(key)
                    samples[key] = np.zeros(size)
                    for i in range(size):
                        rvars = {
                            name: value[i] for name, value in required_variables.items()
                        }
                        samples[key][i] = subset_dict[key].sample(**rvars)
            else:
                logger.debug("{} not a known prior.".format(key))
        return samples

    def get_required_variables(self, key):
        """Returns the required variables to sample a given conditional key.

        Parameters
        ==========
        key : str
            Name of the key that we want to know the required variables for

        Returns
        =======
        dict: key/value pairs of the required variables, taken from each
        required prior's ``least_recently_sampled`` value
        """
        return {
            k: self[k].least_recently_sampled
            for k in getattr(self[key], "required_variables", [])
        }

    def prob(self, sample, **kwargs):
        """

        Parameters
        ==========
        sample: dict
            Dictionary of the samples of which we want to have the probability of
        kwargs:
            The keyword arguments are passed directly to `np.prod`

        Returns
        =======
        float: Joint probability of all individual sample probabilities

        """
        self._prepare_evaluation(*zip(*sample.items()))
        res = [
            self[key].prob(sample[key], **self.get_required_variables(key))
            for key in sample
        ]
        prob = np.prod(res, **kwargs)
        return self.check_prob(sample, prob)

    def ln_prob(self, sample, axis=None, normalized=True):
        """

        Parameters
        ==========
        sample: dict
            Dictionary of the samples of which we want to have the log probability of
        axis: Union[None, int]
            Axis along which the summation is performed
        normalized: bool
            When False, disables calculation of constraint normalization factor
            during prior probability computation. Default value is True.

        Returns
        =======
        float: Joint log probability of all the individual sample probabilities

        """
        self._prepare_evaluation(*zip(*sample.items()))
        res = [
            self[key].ln_prob(sample[key], **self.get_required_variables(key))
            for key in sample
        ]
        ln_prob = np.sum(res, axis=axis)
        return self.check_ln_prob(sample, ln_prob,
                                  normalized=normalized)

    def cdf(self, sample):
        """Evaluate the (conditional) CDF of each key at the provided points.

        Parameters
        ==========
        sample: dict, pandas.DataFrame
            Values at which to evaluate the CDFs

        Returns
        =======
        Same type as ``sample``, holding the CDF values
        """
        self._prepare_evaluation(*zip(*sample.items()))
        res = {
            key: self[key].cdf(sample[key], **self.get_required_variables(key))
            for key in sample
        }
        return sample.__class__(res)

    def rescale(self, keys, theta):
        """Rescale samples from unit cube to prior

        Parameters
        ==========
        keys: list
            List of prior keys to be rescaled
        theta: list
            List of randomly drawn values on a unit cube associated with the prior keys

        Returns
        =======
        list: List of floats containing the rescaled sample
        """
        from matplotlib.cbook import flatten

        keys = list(keys)
        theta = list(theta)
        self._check_resolved()
        self._update_rescale_keys(keys)
        result = dict()
        # Rescale in dependency order so conditional priors see the freshly
        # rescaled values of the parameters they depend on.
        for key, index in zip(
            self.sorted_keys_without_fixed_parameters, self._rescale_indexes
        ):
            result[key] = self[key].rescale(
                theta[index], **self.get_required_variables(key)
            )
            self[key].least_recently_sampled = result[key]
        return list(flatten([result[key] for key in keys]))

    def _update_rescale_keys(self, keys):
        """Cache the position in ``keys`` of each sorted, non-fixed key."""
        if not keys == self._least_recently_rescaled_keys:
            self._rescale_indexes = [
                keys.index(element)
                for element in self.sorted_keys_without_fixed_parameters
            ]
            self._least_recently_rescaled_keys = keys

    def _prepare_evaluation(self, keys, theta):
        """Store ``theta`` as the ``least_recently_sampled`` value of each key."""
        self._check_resolved()
        for key, value in zip(keys, theta):
            self[key].least_recently_sampled = value

    def _check_resolved(self):
        """Raise IllegalConditionsException if conditions are unresolved."""
        if not self._resolved:
            raise IllegalConditionsException(
                "The current set of priors contains unresolveable conditions."
            )

    @property
    def conditional_keys(self):
        """Conditional keys, in resolved dependency order."""
        return self._conditional_keys

    @property
    def unconditional_keys(self):
        """Keys whose priors have no ``condition_func``."""
        return self._unconditional_keys

    @property
    def sorted_keys(self):
        """Unconditional keys followed by conditional keys in dependency order."""
        return self.unconditional_keys + self.conditional_keys

    @property
    def sorted_keys_without_fixed_parameters(self):
        """``sorted_keys`` excluding DeltaFunction and Constraint priors."""
        return [
            key
            for key in self.sorted_keys
            if not isinstance(self[key], (DeltaFunction, Constraint))
        ]

    def __setitem__(self, key, value):
        super(ConditionalPriorDict, self).__setitem__(key, value)
        self._resolve_conditions()

    def __delitem__(self, key):
        super(ConditionalPriorDict, self).__delitem__(key)
        self._resolve_conditions()

906 

907 

class DirichletPriorDict(ConditionalPriorDict):
    def __init__(self, n_dim=None, label="dirichlet_"):
        """Conditional prior dict holding the elements of a Dirichlet prior.

        Creates ``n_dim - 1`` ``DirichletElement`` priors named
        ``<label>0`` ... ``<label>{n_dim - 2}``.

        Parameters
        ==========
        n_dim: int
            Number of dimensions of the Dirichlet distribution (required)
        label: str
            Prefix for the parameter names

        Raises
        ======
        TypeError
            If ``n_dim`` is not given.
        """
        from .conditional import DirichletElement

        if n_dim is None:
            raise TypeError("DirichletPriorDict requires an integer n_dim")
        self.n_dim = n_dim
        self.label = label
        super(DirichletPriorDict, self).__init__(dictionary=dict())
        for ii in range(n_dim - 1):
            self[label + "{}".format(ii)] = DirichletElement(
                order=ii, n_dimensions=n_dim, label=label
            )

    def copy(self, **kwargs):
        """Return a new instance with the same ``n_dim`` and ``label``."""
        return self.__class__(n_dim=self.n_dim, label=self.label)

    def _get_json_dict(self):
        """Serialise only the constructor arguments plus class metadata."""
        total_dict = dict()
        total_dict["__prior_dict__"] = True
        total_dict["__module__"] = self.__module__
        total_dict["__name__"] = self.__class__.__name__
        total_dict["n_dim"] = self.n_dim
        total_dict["label"] = self.label
        return total_dict

    @classmethod
    def _get_from_json_dict(cls, prior_dict):
        """Rebuild an instance from the output of :meth:`_get_json_dict`.

        The class named by ``__module__``/``__name__`` is used when importable,
        otherwise ``cls``. The metadata keys are removed from ``prior_dict``.
        """
        try:
            # Previously ``cls == getattr(...)``: a no-op comparison.
            class_ = getattr(
                import_module(prior_dict["__module__"]), prior_dict["__name__"]
            )
        except ImportError:
            logger.debug(
                "Cannot import prior module {}.{}".format(
                    prior_dict["__module__"], prior_dict["__name__"]
                )
            )
            class_ = cls
        except KeyError:
            logger.debug("Cannot find module name to load")
            class_ = cls
        for key in ["__module__", "__name__", "__prior_dict__"]:
            if key in prior_dict:
                del prior_dict[key]
        obj = class_(**prior_dict)
        return obj

951 

952 

class ConditionalPriorDictException(PriorDictException):
    """Root of the exceptions raised by conditional prior dictionaries."""

955 

956 

def create_default_prior(name, default_priors_file=None):
    """Look up a default prior for the parameter ``name``.

    Parameters
    ==========
    name: str
        Parameter name
    default_priors_file: str, optional
        Prior file to read the defaults from. Without it no default is
        available.

    Returns
    =======
    prior: Prior
        The prior stored under ``name`` in ``default_priors_file``, or None
        when no file is given or the file has no entry for ``name``.
    """
    if default_priors_file is None:
        logger.debug("No prior file given.")
        return None

    defaults = PriorDict(filename=default_priors_file)
    if name not in defaults:
        logger.debug("No default prior found for variable {}.".format(name))
        return None
    return defaults[name]

985 

986 

class IllegalConditionsException(ConditionalPriorDictException):
    """Raised when the conditions among a set of priors cannot be resolved into a sampling order."""