Coverage for bilby/core/prior/dict.py: 89%
461 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import json
2import os
3import re
4from importlib import import_module
5from io import open as ioopen
7import numpy as np
9from .analytical import DeltaFunction
10from .base import Prior, Constraint
11from .joint import JointPrior
12from ..utils import (
13 logger,
14 check_directory_exists_and_if_not_mkdir,
15 BilbyJsonEncoder,
16 decode_bilby_json,
17)
class PriorDict(dict):
    def __init__(self, dictionary=None, filename=None, conversion_function=None):
        """A dictionary of priors

        Parameters
        ==========
        dictionary: Union[dict, str, None]
            If given, a dictionary to generate the prior set.
        filename: Union[str, None]
            If given, a file containing the prior to generate the prior set.
        conversion_function: func
            Function to convert between sampled parameters and constraints.
            Default is no conversion.
        """
        super(PriorDict, self).__init__()
        if isinstance(dictionary, dict):
            self.from_dictionary(dictionary)
        elif isinstance(dictionary, str):
            # A bare string is treated as a path to a prior file.
            # isinstance (rather than `type(...) is str`) also accepts
            # str subclasses, and is the idiomatic type check.
            logger.debug(
                'Argument "dictionary" is a string.'
                + " Assuming it is intended as a file name."
            )
            self.from_file(dictionary)
        elif isinstance(filename, str):
            self.from_file(filename)
        elif dictionary is not None:
            raise ValueError("PriorDict input dictionary not understood")
        # Cache of constraint normalization factors keyed by tuples of keys.
        self._cached_normalizations = {}

        self.convert_floats_to_delta_functions()

        if conversion_function is not None:
            self.conversion_function = conversion_function
        else:
            self.conversion_function = self.default_conversion_function
    def evaluate_constraints(self, sample):
        """Multiply the probabilities of every Constraint prior evaluated on
        the (converted) sample; returns 1 when no constraint applies."""
        out_sample = self.conversion_function(sample)
        prob = 1
        for key in self:
            # Only Constraint entries contribute, and only if the conversion
            # function produced a value for them.
            if isinstance(self[key], Constraint) and key in out_sample:
                prob *= self[key].prob(out_sample[key])
        return prob
64 def default_conversion_function(self, sample):
65 """
66 Placeholder parameter conversion function.
68 Parameters
69 ==========
70 sample: dict
71 Dictionary to convert
73 Returns
74 =======
75 sample: dict
76 Same as input
77 """
78 return sample
    def to_file(self, outdir, label):
        """Write the prior distribution to file.

        Parameters
        ==========
        outdir: str
            output directory name
        label: str
            Output file naming scheme
        """

        check_directory_exists_and_if_not_mkdir(outdir)
        prior_file = os.path.join(outdir, "{}.prior".format(label))
        logger.debug("Writing priors to {}".format(prior_file))
        joint_dists = []
        with open(prior_file, "w") as outfile:
            for key in self.keys():
                if JointPrior in self[key].__class__.__mro__:
                    # Joint priors share one distribution object: write that
                    # object once under a derived name, then reference the
                    # name from each member prior's repr.
                    distname = "_".join(self[key].dist.names) + "_{}".format(
                        self[key].dist.distname
                    )
                    if distname not in joint_dists:
                        joint_dists.append(distname)
                        outfile.write("{} = {}\n".format(distname, self[key].dist))
                    diststr = repr(self[key].dist)
                    priorstr = repr(self[key])
                    outfile.write(
                        "{} = {}\n".format(key, priorstr.replace(diststr, distname))
                    )
                else:
                    outfile.write("{} = {}\n".format(key, self[key]))
112 def _get_json_dict(self):
113 self.convert_floats_to_delta_functions()
114 total_dict = {key: json.loads(self[key].to_json()) for key in self}
115 total_dict["__prior_dict__"] = True
116 total_dict["__module__"] = self.__module__
117 total_dict["__name__"] = self.__class__.__name__
118 return total_dict
    def to_json(self, outdir, label):
        """Serialize the prior dict to ``{outdir}/{label}_prior.json``."""
        check_directory_exists_and_if_not_mkdir(outdir)
        prior_file = os.path.join(outdir, "{}_prior.json".format(label))
        logger.debug("Writing priors to {}".format(prior_file))
        with open(prior_file, "w") as outfile:
            json.dump(self._get_json_dict(), outfile, cls=BilbyJsonEncoder, indent=2)
127 def from_file(self, filename):
128 """Reads in a prior from a file specification
130 Parameters
131 ==========
132 filename: str
133 Name of the file to be read in
135 Notes
136 =====
137 Lines beginning with '#' or empty lines will be ignored.
138 Priors can be loaded from:
140 - bilby.core.prior as, e.g., :code:`foo = Uniform(minimum=0, maximum=1)`
141 - floats, e.g., :code:`foo = 1`
142 - bilby.gw.prior as, e.g., :code:`foo = bilby.gw.prior.AlignedSpin()`
143 - other external modules, e.g., :code:`foo = my.module.CustomPrior(...)`
145 """
147 comments = ["#", "\n"]
148 prior = dict()
149 with ioopen(filename, "r", encoding="unicode_escape") as f:
150 for line in f:
151 if line[0] in comments:
152 continue
153 line.replace(" ", "")
154 elements = line.split("=")
155 key = elements[0].replace(" ", "")
156 val = "=".join(elements[1:]).strip()
157 prior[key] = val
158 self.from_dictionary(prior)
    @classmethod
    def _get_from_json_dict(cls, prior_dict):
        """Reconstruct a PriorDict (or subclass) from a decoded JSON dict.

        Falls back to ``cls`` when the stored class cannot be imported or the
        metadata keys are absent.
        """
        try:
            class_ = getattr(
                import_module(prior_dict["__module__"]), prior_dict["__name__"]
            )
        except ImportError:
            logger.debug(
                "Cannot import prior module {}.{}".format(
                    prior_dict["__module__"], prior_dict["__name__"]
                )
            )
            class_ = cls
        except KeyError:
            logger.debug("Cannot find module name to load")
            class_ = cls
        # Strip metadata entries so only actual prior entries remain.
        for key in ["__module__", "__name__", "__prior_dict__"]:
            if key in prior_dict:
                del prior_dict[key]
        obj = class_(prior_dict)
        return obj
    @classmethod
    def from_json(cls, filename):
        """Reads in a prior from a json file

        Parameters
        ==========
        filename: str
            Name of the file to be read in
        """
        with open(filename, "r") as ff:
            obj = json.load(ff, object_hook=decode_bilby_json)

        # make sure priors containing JointDists are properly handled and point
        # to the same object when required
        jointdists = {}
        for key in obj:
            if isinstance(obj[key], JointPrior):
                for name in obj[key].dist.names:
                    jointdists[name] = obj[key].dist
        # set dist for joint values so that they point to the same object
        for key in obj:
            if isinstance(obj[key], JointPrior):
                obj[key].dist = jointdists[key]

        return obj
    def from_dictionary(self, dictionary):
        """Populate self from a dict whose values may be Prior instances,
        numbers (-> DeltaFunction), repr-style strings, or nested
        JSON-style dicts."""
        # Multivariate distribution objects collected here by name, so later
        # MultivariateGaussian entries can reference them via "dist=".
        mvgkwargs = {}
        for key in list(dictionary.keys()):
            val = dictionary[key]
            if isinstance(val, Prior):
                continue
            elif isinstance(val, (int, float)):
                dictionary[key] = DeltaFunction(peak=val)
            elif isinstance(val, str):
                # Split "ClassName(arg, ...)" into the class path and the
                # argument string (re-joined in case args contain "(").
                cls = val.split("(")[0]
                args = "(".join(val.split("(")[1:])[:-1]
                try:
                    # A numeric string also becomes a DeltaFunction.
                    dictionary[key] = DeltaFunction(peak=float(cls))
                    logger.debug("{} converted to DeltaFunction prior".format(key))
                    continue
                except ValueError:
                    pass
                if "." in cls:
                    # Fully qualified name: import the named module.
                    module = ".".join(cls.split(".")[:-1])
                    cls = cls.split(".")[-1]
                else:
                    # Bare name: default to this module's parent package.
                    module = __name__.replace(
                        "." + os.path.basename(__file__).replace(".py", ""), ""
                    )
                try:
                    # Third getattr argument keeps the string when the
                    # attribute is missing; handled below.
                    cls = getattr(import_module(module), cls, cls)
                except ModuleNotFoundError:
                    logger.error(
                        "Cannot import prior class {} for entry: {}={}".format(
                            cls, key, val
                        )
                    )
                    raise
                if key.lower() in ["conversion_function", "condition_func"]:
                    # These entries configure the dict itself, not a prior.
                    setattr(self, key, cls)
                elif isinstance(cls, str):
                    # Attribute lookup failed; only an error if it looked
                    # like a class instantiation.
                    if "(" in val:
                        raise TypeError("Unable to parse prior class {}".format(cls))
                    else:
                        continue
                elif cls.__name__ in [
                    "MultivariateGaussianDist",
                    "MultivariateNormalDist",
                ]:
                    # Store the shared distribution; it is not itself a prior.
                    dictionary.pop(key)
                    if key not in mvgkwargs:
                        mvgkwargs[key] = cls.from_repr(args)
                elif cls.__name__ in ["MultivariateGaussian", "MultivariateNormal"]:
                    # Parse every argument except "dist=", which is resolved
                    # against the previously collected distributions.
                    mgkwargs = {
                        item[0].strip(): cls._parse_argument_string(item[1])
                        for item in cls._split_repr(
                            ", ".join(
                                [arg for arg in args.split(",") if "dist=" not in arg]
                            )
                        ).items()
                    }
                    keymatch = re.match(r"dist=(?P<distkey>\S+),", args)
                    if keymatch is None:
                        raise ValueError(
                            "'dist' argument for MultivariateGaussian is not specified"
                        )

                    if keymatch["distkey"] not in mvgkwargs:
                        raise ValueError(
                            f"MultivariateGaussianDist {keymatch['distkey']} must be defined before {cls.__name__}"
                        )

                    mgkwargs["dist"] = mvgkwargs[keymatch["distkey"]]
                    dictionary[key] = cls(**mgkwargs)
                else:
                    try:
                        dictionary[key] = cls.from_repr(args)
                    except TypeError as e:
                        raise TypeError(
                            "Unable to parse prior, bad entry: {} "
                            "= {}. Error message {}".format(key, val, e)
                        )
            elif isinstance(val, dict):
                # JSON-style nested dict: try to rebuild the stored class.
                try:
                    _class = getattr(
                        import_module(val.get("__module__", "none")),
                        val.get("__name__", "none"),
                    )
                    dictionary[key] = _class(**val.get("kwargs", dict()))
                except ImportError:
                    logger.debug(
                        "Cannot import prior module {}.{}".format(
                            val.get("__module__", "none"), val.get("__name__", "none")
                        )
                    )
                    logger.warning(
                        "Cannot convert {} into a prior object. "
                        "Leaving as dictionary.".format(key)
                    )
                    continue
            else:
                raise TypeError(
                    "Unable to parse prior, bad entry: {} "
                    "= {} of type {}".format(key, val, type(val))
                )
        self.update(dictionary)
310 def convert_floats_to_delta_functions(self):
311 """Convert all float parameters to delta functions"""
312 for key in self:
313 if isinstance(self[key], Prior):
314 continue
315 elif isinstance(self[key], float) or isinstance(self[key], int):
316 self[key] = DeltaFunction(self[key])
317 logger.debug("{} converted to delta function prior.".format(key))
318 else:
319 logger.debug(
320 "{} cannot be converted to delta function prior.".format(key)
321 )
    def fill_priors(self, likelihood, default_priors_file=None):
        """
        Fill dictionary of priors based on required parameters of likelihood

        Any floats in prior will be converted to delta function prior. Any
        required, non-specified parameters will use the default.

        Note: if `likelihood` has `non_standard_sampling_parameter_keys`, then
        this will set-up default priors for those as well.

        Parameters
        ==========
        likelihood: bilby.likelihood.GravitationalWaveTransient instance
            Used to infer the set of parameters to fill the prior with
        default_priors_file: str, optional
            If given, a file containing the default priors.

        Returns
        =======
        prior: dict
            The filled prior dictionary

        """
        self.convert_floats_to_delta_functions()

        missing_keys = set(likelihood.parameters) - set(self.keys())

        for missing_key in missing_keys:
            # Only fill parameters that are not redundant with existing ones.
            if not self.test_redundancy(missing_key):
                default_prior = create_default_prior(missing_key, default_priors_file)
                if default_prior is None:
                    set_val = likelihood.parameters[missing_key]
                    logger.warning(
                        "Parameter {} has no default prior and is set to {}, this"
                        " will not be sampled and may cause an error.".format(
                            missing_key, set_val
                        )
                    )
                else:
                    self[missing_key] = default_prior

        for key in self:
            self.test_redundancy(key)
    def sample(self, size=None):
        """Draw samples from the prior set

        Parameters
        ==========
        size: int or tuple of ints, optional
            See numpy.random.uniform docs

        Returns
        =======
        dict: Dictionary of the samples
        """
        # Delegate to the constrained sampler over every key so rejection
        # sampling against Constraint priors is applied automatically.
        return self.sample_subset_constrained(keys=list(self.keys()), size=size)
383 def sample_subset_constrained_as_array(self, keys=iter([]), size=None):
384 """Return an array of samples
386 Parameters
387 ==========
388 keys: list
389 A list of keys to sample in
390 size: int
391 The number of samples to draw
393 Returns
394 =======
395 array: array_like
396 An array of shape (len(key), size) of the samples (ordered by keys)
397 """
398 samples_dict = self.sample_subset_constrained(keys=keys, size=size)
399 samples_dict = {key: np.atleast_1d(val) for key, val in samples_dict.items()}
400 samples_list = [samples_dict[key] for key in keys]
401 return np.array(samples_list)
    def sample_subset(self, keys=iter([]), size=None):
        """Draw samples from the prior set for parameters which are not a DeltaFunction

        Parameters
        ==========
        keys: list
            List of prior keys to draw samples from
        size: int or tuple of ints, optional
            See numpy.random.uniform docs

        Returns
        =======
        dict: Dictionary of the drawn samples
        """
        self.convert_floats_to_delta_functions()
        samples = dict()
        for key in keys:
            # Constraints are never sampled directly; they only filter.
            if isinstance(self[key], Constraint):
                continue
            elif isinstance(self[key], Prior):
                samples[key] = self[key].sample(size=size)
            else:
                logger.debug("{} not a known prior.".format(key))
        return samples
428 @property
429 def non_fixed_keys(self):
430 keys = self.keys()
431 keys = [k for k in keys if isinstance(self[k], Prior)]
432 keys = [k for k in keys if self[k].is_fixed is False]
433 keys = [k for k in keys if k not in self.constraint_keys]
434 return keys
436 @property
437 def fixed_keys(self):
438 return [
439 k for k, p in self.items() if (p.is_fixed and k not in self.constraint_keys)
440 ]
442 @property
443 def constraint_keys(self):
444 return [k for k, p in self.items() if isinstance(p, Constraint)]
446 def sample_subset_constrained(self, keys=iter([]), size=None):
447 if size is None or size == 1:
448 while True:
449 sample = self.sample_subset(keys=keys, size=size)
450 if self.evaluate_constraints(sample):
451 return sample
452 else:
453 needed = np.prod(size)
454 for key in keys.copy():
455 if isinstance(self[key], Constraint):
456 del keys[keys.index(key)]
457 all_samples = {key: np.array([]) for key in keys}
458 _first_key = list(all_samples.keys())[0]
459 while len(all_samples[_first_key]) < needed:
460 samples = self.sample_subset(keys=keys, size=needed)
461 keep = np.array(self.evaluate_constraints(samples), dtype=bool)
462 for key in keys:
463 all_samples[key] = np.hstack(
464 [all_samples[key], samples[key][keep].flatten()]
465 )
466 all_samples = {
467 key: np.reshape(all_samples[key][:needed], size) for key in keys
468 }
469 return all_samples
    def normalize_constraint_factor(
        self, keys, min_accept=10000, sampling_chunk=50000, nrepeats=10
    ):
        """Estimate (and cache) the factor by which constraints shrink the
        prior volume for the given tuple of keys.

        The factor is the mean of ``nrepeats`` Monte-Carlo estimates,
        rounded to the precision implied by their scatter.
        """
        if keys in self._cached_normalizations.keys():
            return self._cached_normalizations[keys]
        else:
            factor_estimates = [
                self._estimate_normalization(keys, min_accept, sampling_chunk)
                for _ in range(nrepeats)
            ]
            factor = np.mean(factor_estimates)
            if np.std(factor_estimates) > 0:
                # Round to the decimal place set by ~3x the estimator scatter.
                decimals = int(-np.floor(np.log10(3 * np.std(factor_estimates))))
                factor_rounded = np.round(factor, decimals)
            else:
                factor_rounded = factor
            self._cached_normalizations[keys] = factor_rounded
            return factor_rounded
    def _estimate_normalization(self, keys, min_accept, sampling_chunk):
        """Single Monte-Carlo estimate of total/accepted ratio under the
        constraints."""
        samples = self.sample_subset(keys=keys, size=sampling_chunk)
        keep = np.atleast_1d(self.evaluate_constraints(samples))
        if len(keep) == 1:
            # Scalar constraint result: no per-sample accept/reject flags to
            # count, so the factor is trivially 1 (cached for reuse).
            self._cached_normalizations[keys] = 1
            return 1
        all_samples = {key: np.array([]) for key in keys}
        # Accumulate chunks until at least min_accept samples are accepted.
        while np.count_nonzero(keep) < min_accept:
            samples = self.sample_subset(keys=keys, size=sampling_chunk)
            for key in samples:
                all_samples[key] = np.hstack([all_samples[key], samples[key].flatten()])
            keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
        factor = len(keep) / np.count_nonzero(keep)
        return factor
    def prob(self, sample, **kwargs):
        """
        Parameters
        ==========
        sample: dict
            Dictionary of the samples of which we want to have the probability of
        kwargs:
            The keyword arguments are passed directly to `np.prod`

        Returns
        =======
        float: Joint probability of all individual sample probabilities

        """
        prob = np.prod([self[key].prob(sample[key]) for key in sample], **kwargs)

        # Apply constraint rejection/normalization to the raw product.
        return self.check_prob(sample, prob)
    def check_prob(self, sample, prob):
        """Apply constraint handling to raw probabilities: zero out samples
        that fail the constraints and rescale the rest by the cached
        normalization factor."""
        ratio = self.normalize_constraint_factor(tuple(sample.keys()))
        if np.all(prob == 0.0):
            # Everything is zero already; constraints cannot change that.
            return prob * ratio
        else:
            if isinstance(prob, float):
                if self.evaluate_constraints(sample):
                    return prob * ratio
                else:
                    return 0.0
            else:
                # Array case: keep only entries passing the constraints.
                constrained_prob = np.zeros_like(prob)
                keep = np.array(self.evaluate_constraints(sample), dtype=bool)
                constrained_prob[keep] = prob[keep] * ratio
                return constrained_prob
    def ln_prob(self, sample, axis=None, normalized=True):
        """
        Parameters
        ==========
        sample: dict
            Dictionary of the samples of which to calculate the log probability
        axis: None or int
            Axis along which the summation is performed
        normalized: bool
            When False, disables calculation of constraint normalization factor
            during prior probability computation. Default value is True.

        Returns
        =======
        float or ndarray:
            Joint log probability of all the individual sample probabilities

        """
        ln_prob = np.sum([self[key].ln_prob(sample[key]) for key in sample], axis=axis)

        # Apply constraint rejection/normalization to the raw sum.
        return self.check_ln_prob(sample, ln_prob,
                                  normalized=normalized)
    def check_ln_prob(self, sample, ln_prob, normalized=True):
        """Constraint handling for log-probabilities: -inf for samples that
        fail the constraints; optionally add the log normalization factor."""
        if normalized:
            ratio = self.normalize_constraint_factor(tuple(sample.keys()))
        else:
            ratio = 1
        if np.all(np.isinf(ln_prob)):
            # Already fully excluded; nothing to renormalize.
            return ln_prob
        else:
            if isinstance(ln_prob, float):
                if self.evaluate_constraints(sample):
                    return ln_prob + np.log(ratio)
                else:
                    return -np.inf
            else:
                # Array case: keep only entries passing the constraints.
                constrained_ln_prob = -np.inf * np.ones_like(ln_prob)
                keep = np.array(self.evaluate_constraints(sample), dtype=bool)
                constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio)
                return constrained_ln_prob
582 def cdf(self, sample):
583 """Evaluate the cumulative distribution function at the provided points
585 Parameters
586 ----------
587 sample: dict, pandas.DataFrame
588 Dictionary of the samples of which to calculate the CDF
590 Returns
591 -------
592 dict, pandas.DataFrame: Dictionary containing the CDF values
594 """
595 return sample.__class__(
596 {key: self[key].cdf(sample) for key, sample in sample.items()}
597 )
    def rescale(self, keys, theta):
        """Rescale samples from unit cube to prior

        Parameters
        ==========
        keys: list
            List of prior keys to be rescaled
        theta: list
            List of randomly drawn values on a unit cube associated with the prior keys

        Returns
        =======
        list: List of floats containing the rescaled sample
        """
        # Local import so matplotlib is only needed when rescaling is used
        # (presumably to keep it an optional dependency — confirm).
        from matplotlib.cbook import flatten

        return list(
            flatten([self[key].rescale(sample) for key, sample in zip(keys, theta)])
        )
619 def test_redundancy(self, key, disable_logging=False):
620 """Empty redundancy test, should be overwritten in subclasses"""
621 return False
    def test_has_redundant_keys(self):
        """
        Test whether there are redundant keys in self.

        Returns
        =======
        bool: Whether there are redundancies or not
        """
        redundant = False
        for key in self:
            if isinstance(self[key], Constraint):
                continue
            # Remove the key and ask the remaining priors whether it would
            # be redundant with what is left.
            temp = self.copy()
            del temp[key]
            if temp.test_redundancy(key, disable_logging=True):
                logger.warning(
                    "{} is a redundant key in this {}.".format(
                        key, self.__class__.__name__
                    )
                )
                redundant = True
        return redundant
646 def copy(self):
647 """
648 We have to overwrite the copy method as it fails due to the presence of
649 defaults.
650 """
651 return self.__class__(dictionary=dict(self))
class PriorDictException(Exception):
    """General base class for all prior dict exceptions"""
class ConditionalPriorDict(PriorDict):
    def __init__(self, dictionary=None, filename=None, conversion_function=None):
        """
        Parameters
        ==========
        dictionary: dict
            See parent class
        filename: str
            See parent class
        """
        # Bookkeeping for dependency resolution and rescale-order caching;
        # must exist before the parent constructor inserts any priors
        # (insertion triggers _resolve_conditions via __setitem__).
        self._conditional_keys = []
        self._unconditional_keys = []
        self._rescale_keys = []
        self._rescale_indexes = []
        self._least_recently_rescaled_keys = []
        super(ConditionalPriorDict, self).__init__(
            dictionary=dictionary,
            filename=filename,
            conversion_function=conversion_function,
        )
        self._resolved = False
        self._resolve_conditions()
    def _resolve_conditions(self):
        """
        Resolves how priors depend on each other and automatically
        sorts them into the right order.
        1. All unconditional priors are put in front in arbitrary order
        2. We loop through all the unsorted conditional priors to find
        which one can go next
        3. We repeat step 2 len(self) number of times to make sure that
        all conditional priors will be sorted in order
        4. We set the `self._resolved` flag to True if all conditional
        priors were added in the right order
        """
        self._unconditional_keys = [
            key for key in self.keys() if not hasattr(self[key], "condition_func")
        ]
        conditional_keys_unsorted = [
            key for key in self.keys() if hasattr(self[key], "condition_func")
        ]
        self._conditional_keys = []
        for _ in range(len(self)):
            # Iterate over a slice copy since we remove while looping.
            for key in conditional_keys_unsorted[:]:
                if self._check_conditions_resolved(key, self.sorted_keys):
                    self._conditional_keys.append(key)
                    conditional_keys_unsorted.remove(key)
        self._resolved = True
        if len(conditional_keys_unsorted) != 0:
            # Some prior's requirements could never be satisfied (e.g. a
            # cyclic dependency): mark the dict unresolved.
            self._resolved = False
711 def _check_conditions_resolved(self, key, sampled_keys):
712 """Checks if all required variables have already been sampled so we can sample this key"""
713 conditions_resolved = True
714 for k in self[key].required_variables:
715 if k not in sampled_keys:
716 conditions_resolved = False
717 return conditions_resolved
    def sample_subset(self, keys=iter([]), size=None):
        """Draw samples for ``keys``, resolving conditional dependencies.

        DeltaFunction priors outside ``keys`` are temporarily included so
        conditionals depending on them can still be evaluated; only the
        requested keys appear in the returned dict.
        """
        self.convert_floats_to_delta_functions()
        add_delta_keys = [
            key
            for key in self.keys()
            if key not in keys and isinstance(self[key], DeltaFunction)
        ]
        use_keys = add_delta_keys + list(keys)
        # Build a sub-dict so its resolved sampling order covers use_keys only.
        subset_dict = ConditionalPriorDict({key: self[key] for key in use_keys})
        if not subset_dict._resolved:
            raise IllegalConditionsException(
                "The current set of priors contains unresolvable conditions."
            )
        samples = dict()
        for key in subset_dict.sorted_keys:
            if key not in keys or isinstance(self[key], Constraint):
                continue
            if isinstance(self[key], Prior):
                try:
                    samples[key] = subset_dict[key].sample(
                        size=size, **subset_dict.get_required_variables(key)
                    )
                except ValueError:
                    # Some prior classes can not handle an array of conditional parameters (e.g. alpha for PowerLaw)
                    # If that is the case, we sample each sample individually.
                    required_variables = subset_dict.get_required_variables(key)
                    samples[key] = np.zeros(size)
                    for i in range(size):
                        rvars = {
                            key: value[i] for key, value in required_variables.items()
                        }
                        samples[key][i] = subset_dict[key].sample(**rvars)
            else:
                logger.debug("{} not a known prior.".format(key))
        return samples
755 def get_required_variables(self, key):
756 """Returns the required variables to sample a given conditional key.
758 Parameters
759 ==========
760 key : str
761 Name of the key that we want to know the required variables for
763 Returns
764 =======
765 dict: key/value pairs of the required variables
766 """
767 return {
768 k: self[k].least_recently_sampled
769 for k in getattr(self[key], "required_variables", [])
770 }
    def prob(self, sample, **kwargs):
        """
        Parameters
        ==========
        sample: dict
            Dictionary of the samples of which we want to have the probability of
        kwargs:
            The keyword arguments are passed directly to `np.prod`

        Returns
        =======
        float: Joint probability of all individual sample probabilities

        """
        # Store sampled values as least_recently_sampled so conditional
        # priors can look up their required variables.
        self._prepare_evaluation(*zip(*sample.items()))
        res = [
            self[key].prob(sample[key], **self.get_required_variables(key))
            for key in sample
        ]
        prob = np.prod(res, **kwargs)
        return self.check_prob(sample, prob)
    def ln_prob(self, sample, axis=None, normalized=True):
        """
        Parameters
        ==========
        sample: dict
            Dictionary of the samples of which we want to have the log probability of
        axis: Union[None, int]
            Axis along which the summation is performed
        normalized: bool
            When False, disables calculation of constraint normalization factor
            during prior probability computation. Default value is True.

        Returns
        =======
        float: Joint log probability of all the individual sample probabilities

        """
        # Store sampled values as least_recently_sampled so conditional
        # priors can look up their required variables.
        self._prepare_evaluation(*zip(*sample.items()))
        res = [
            self[key].ln_prob(sample[key], **self.get_required_variables(key))
            for key in sample
        ]
        ln_prob = np.sum(res, axis=axis)
        return self.check_ln_prob(sample, ln_prob,
                                  normalized=normalized)
    def cdf(self, sample):
        """Evaluate each prior's CDF at the sampled values, passing the
        conditional variables each prior requires."""
        self._prepare_evaluation(*zip(*sample.items()))
        res = {
            key: self[key].cdf(sample[key], **self.get_required_variables(key))
            for key in sample
        }
        # Preserve the container type of the input (dict or DataFrame).
        return sample.__class__(res)
    def rescale(self, keys, theta):
        """Rescale samples from unit cube to prior

        Parameters
        ==========
        keys: list
            List of prior keys to be rescaled
        theta: list
            List of randomly drawn values on a unit cube associated with the prior keys

        Returns
        =======
        list: List of floats containing the rescaled sample
        """
        from matplotlib.cbook import flatten

        keys = list(keys)
        theta = list(theta)
        self._check_resolved()
        self._update_rescale_keys(keys)
        result = dict()
        # Rescale in dependency order so each conditional prior can read the
        # already-rescaled values of its required variables.
        for key, index in zip(
            self.sorted_keys_without_fixed_parameters, self._rescale_indexes
        ):
            result[key] = self[key].rescale(
                theta[index], **self.get_required_variables(key)
            )
            self[key].least_recently_sampled = result[key]
        # Return flattened values in the caller's original key order.
        return list(flatten([result[key] for key in keys]))
860 def _update_rescale_keys(self, keys):
861 if not keys == self._least_recently_rescaled_keys:
862 self._rescale_indexes = [
863 keys.index(element)
864 for element in self.sorted_keys_without_fixed_parameters
865 ]
866 self._least_recently_rescaled_keys = keys
    def _prepare_evaluation(self, keys, theta):
        """Store the supplied values as each prior's least_recently_sampled
        so conditional priors can resolve their required variables."""
        self._check_resolved()
        for key, value in zip(keys, theta):
            self[key].least_recently_sampled = value
    def _check_resolved(self):
        # Guard: evaluation/rescaling is meaningless when the dependency
        # graph could not be sorted (e.g. circular conditions).
        if not self._resolved:
            raise IllegalConditionsException(
                "The current set of priors contains unresolveable conditions."
            )
    @property
    def conditional_keys(self):
        # Keys of priors with a condition_func, in resolved sampling order.
        return self._conditional_keys
    @property
    def unconditional_keys(self):
        # Keys of priors without a condition_func (sampled first).
        return self._unconditional_keys
887 @property
888 def sorted_keys(self):
889 return self.unconditional_keys + self.conditional_keys
    @property
    def sorted_keys_without_fixed_parameters(self):
        # Sampling order with fixed (DeltaFunction) and Constraint priors
        # removed; this is the order used by rescale.
        return [
            key
            for key in self.sorted_keys
            if not isinstance(self[key], (DeltaFunction, Constraint))
        ]
    def __setitem__(self, key, value):
        # Any insertion can change the dependency graph; re-resolve ordering.
        super(ConditionalPriorDict, self).__setitem__(key, value)
        self._resolve_conditions()
    def __delitem__(self, key):
        # Any removal can change the dependency graph; re-resolve ordering.
        super(ConditionalPriorDict, self).__delitem__(key)
        self._resolve_conditions()
class DirichletPriorDict(ConditionalPriorDict):
    def __init__(self, n_dim=None, label="dirichlet_"):
        """Conditional prior dict built from n_dim - 1 DirichletElement
        priors named ``{label}0 .. {label}{n_dim - 2}``.

        Parameters
        ==========
        n_dim: int
            Number of dimensions (must be provided; the None default will
            fail at ``range(n_dim - 1)``).
        label: str
            Common prefix for the element parameter names.
        """
        from .conditional import DirichletElement

        self.n_dim = n_dim
        self.label = label
        super(DirichletPriorDict, self).__init__(dictionary=dict())
        for ii in range(n_dim - 1):
            self[label + "{}".format(ii)] = DirichletElement(
                order=ii, n_dimensions=n_dim, label=label
            )
    def copy(self, **kwargs):
        # kwargs accepted for signature compatibility but ignored; the dict
        # is fully determined by its dimension and label.
        return self.__class__(n_dim=self.n_dim, label=self.label)
923 def _get_json_dict(self):
924 total_dict = dict()
925 total_dict["__prior_dict__"] = True
926 total_dict["__module__"] = self.__module__
927 total_dict["__name__"] = self.__class__.__name__
928 total_dict["n_dim"] = self.n_dim
929 total_dict["label"] = self.label
930 return total_dict
    @classmethod
    def _get_from_json_dict(cls, prior_dict):
        """Reconstruct a DirichletPriorDict from a decoded JSON dict."""
        try:
            # NOTE(review): `==` discards the looked-up class, so this only
            # probes whether the stored module/name can be imported — `cls`
            # itself is always instantiated below. Confirm this is
            # intentional (compare the assignment to `class_` in
            # PriorDict._get_from_json_dict).
            cls == getattr(
                import_module(prior_dict["__module__"]), prior_dict["__name__"]
            )
        except ImportError:
            logger.debug(
                "Cannot import prior module {}.{}".format(
                    prior_dict["__module__"], prior_dict["__name__"]
                )
            )
        except KeyError:
            logger.debug("Cannot find module name to load")
        # Remove metadata; the rest are constructor kwargs (n_dim, label).
        for key in ["__module__", "__name__", "__prior_dict__"]:
            if key in prior_dict:
                del prior_dict[key]
        obj = cls(**prior_dict)
        return obj
class ConditionalPriorDictException(PriorDictException):
    """General base class for all conditional prior dict exceptions"""
def create_default_prior(name, default_priors_file=None):
    """Make a default prior for a parameter with a known name.

    Parameters
    ==========
    name: str
        Parameter name
    default_priors_file: str, optional
        If given, a file containing the default priors.

    Returns
    =======
    prior: Prior
        Default prior distribution for that parameter, if unknown None is
        returned.
    """

    # Guard clauses: no file, or name not in the file, both yield None.
    if default_priors_file is None:
        logger.debug("No prior file given.")
        return None
    defaults = PriorDict(filename=default_priors_file)
    if name not in defaults.keys():
        logger.debug("No default prior found for variable {}.".format(name))
        return None
    return defaults[name]
class IllegalConditionsException(ConditionalPriorDictException):
    """Exception class to handle prior dicts that contain unresolvable conditions."""