Coverage for bilby/core/prior/joint.py: 85%
408 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import re
3import numpy as np
4import scipy.stats
5from scipy.special import erfinv
7from .base import Prior, PriorException
8from ..utils import logger, infer_args_from_method, get_dict_with_properties
9from ..utils import random
class BaseJointPriorDist(object):
    def __init__(self, names, bounds=None):
        """
        A base class for joint prior distributions over several parameters.

        Child classes override ``_ln_prob``, ``_sample`` and ``_rescale`` to
        define a concrete joint distribution.

        Parameters
        ==========
        names: list (required)
            A list of the parameter names in the JointPriorDist. The
            listed parameters must have the same order that they appear in
            the lists of statistical parameters that may be passed in child
            classes. A single non-list name is wrapped in a list.
        bounds: list, tuple, numpy.ndarray or dict (optional)
            Bounds on each parameter, either as a sequence of
            ``(lower, upper)`` pairs in the same order as ``names`` or as a
            dictionary mapping each parameter name to its ``(lower, upper)``
            pair (the form in which bounds are stored and shown by
            ``__repr__``). The defaults are for bounds at +/- infinity.

        Raises
        ======
        ValueError
            If the number of bounds does not match the number of parameters,
            a bound does not contain exactly two values, or a lower bound is
            not strictly below its upper bound.
        TypeError
            If ``bounds`` or one of its entries has an unsupported type.
        """
        self.distname = "joint_dist"
        if not isinstance(names, list):
            self.names = [names]
        else:
            self.names = names

        self.num_vars = len(self.names)

        # set the bounds for each parameter
        self.bounds = {
            name: bound for name, bound in zip(self.names, self._parse_bounds(bounds))
        }

        self._current_sample = {}  # initialise empty sample
        self._uncorrelated = None
        self._current_lnprob = None

        # a dictionary of the parameters as requested by the prior
        self.requested_parameters = dict()
        self.reset_request()

        # a dictionary of the rescaled parameters
        self.rescale_parameters = dict()
        self.reset_rescale()

        # a list of sampled parameters
        self.reset_sampled()

    def _parse_bounds(self, bounds):
        """Validate ``bounds`` and return one ``(lower, upper)`` entry per parameter."""
        if bounds is None:
            return [(-np.inf, np.inf) for _ in self.names]

        if isinstance(bounds, dict):
            # the {name: (lower, upper)} form written by __repr__
            if set(bounds) != set(self.names):
                raise ValueError("Bounds dictionary keys must match the parameter names")
            bounds = [bounds[name] for name in self.names]
        elif isinstance(bounds, np.ndarray):
            # plain lists keep the stored bounds comparable with ==
            bounds = bounds.tolist()

        if not isinstance(bounds, (list, tuple)):
            raise TypeError("Bounds must be a list")

        if len(bounds) != len(self):
            raise ValueError("Wrong number of parameter bounds")

        for bound in bounds:
            # check each individual bound (not the container of bounds)
            if not isinstance(bound, (list, tuple, np.ndarray)):
                raise TypeError("Bound must be a list")
            if len(bound) != 2:
                raise ValueError("Bounds must contain an upper and lower value.")
            if bound[1] <= bound[0]:
                raise ValueError("Bounds are not properly set")
        return list(bounds)

    def reset_sampled(self):
        """Forget which parameters have been sampled and clear the current sample."""
        self.sampled_parameters = []
        self.current_sample = {}

    def filled_request(self):
        """
        Check if all requested parameters have been filled.
        """
        return not np.any([val is None for val in self.requested_parameters.values()])

    def reset_request(self):
        """
        Reset the requested parameters to None.
        """
        for name in self.names:
            self.requested_parameters[name] = None

    def filled_rescale(self):
        """
        Check if all the rescaled parameters have been filled.
        """
        return not np.any([val is None for val in self.rescale_parameters.values()])

    def reset_rescale(self):
        """
        Reset the rescaled parameters to None.
        """
        for name in self.names:
            self.rescale_parameters[name] = None

    def get_instantiation_dict(self):
        """Return the keyword arguments needed to re-create this instance."""
        subclass_args = infer_args_from_method(self.__init__)
        dict_with_properties = get_dict_with_properties(self)
        instantiation_dict = dict()
        for key in subclass_args:
            if isinstance(dict_with_properties[key], list):
                value = np.asarray(dict_with_properties[key]).tolist()
            else:
                value = dict_with_properties[key]
            instantiation_dict[key] = value
        return instantiation_dict

    def __len__(self):
        return len(self.names)

    def __repr__(self):
        """Overrides the special method __repr__.

        Returns a representation of this instance that resembles how it is instantiated.
        Works correctly for all child classes

        Returns
        =======
        str: A string representation of this instance

        """
        dist_name = self.__class__.__name__
        instantiation_dict = self.get_instantiation_dict()
        args = ", ".join(
            [
                "{}={}".format(key, repr(instantiation_dict[key]))
                for key in instantiation_dict
            ]
        )
        return "{}({})".format(dist_name, args)

    def prob(self, samp):
        """
        Get the probability of a sample. For bounded priors the
        probability will not be properly normalised.
        """
        return np.exp(self.ln_prob(samp))

    def _as_2d_samples(self, value):
        """Return ``value`` as an (N, num_vars) array, raising ValueError on a bad shape."""
        samp = np.array(value)
        if len(samp.shape) == 1:
            samp = samp.reshape(1, self.num_vars)

        if len(samp.shape) != 2:
            raise ValueError("Array is the wrong shape")
        elif samp.shape[1] != self.num_vars:
            raise ValueError("Array is the wrong shape")
        return samp

    def _check_samp(self, value):
        """
        Convert ``value`` to a sample array and flag samples outside the bounds.

        Parameters
        ==========
        value: array_like
            A 1d vector of the sample, or 2d array of sample values with shape
            NxM, where N is the number of samples and M is the number of
            parameters.

        Returns
        =======
        samp: array_like
            returns the input value as an NxM sample array
        outbounds: array_like
            Boolean array of length N that is True for every sample in samp
            with at least one parameter outside its bounds
        """
        samp = self._as_2d_samples(value)

        # accumulate over all parameters: a sample is out of bounds if ANY of
        # its parameters lies outside that parameter's range
        outbounds = np.zeros(samp.shape[0], dtype=bool)
        for s, bound in zip(samp.T, self.bounds.values()):
            outbounds |= (s < bound[0]) | (s > bound[1])
        return samp, outbounds

    def ln_prob(self, value):
        """
        Get the log-probability of a sample. For bounded priors the
        probability will not be properly normalised.

        Parameters
        ==========
        value: array_like
            A 1d vector of the sample, or 2d array of sample values with shape
            NxM, where N is the number of samples and M is the number of
            parameters.

        Returns
        =======
        float or array_like:
            A single value for a single sample, otherwise one value per sample.
        """
        samp, outbounds = self._check_samp(value)
        lnprob = -np.inf * np.ones(samp.shape[0])
        lnprob = self._ln_prob(samp, lnprob, outbounds)
        if samp.shape[0] == 1:
            return lnprob[0]
        else:
            return lnprob

    def _ln_prob(self, samp, lnprob, outbounds):
        """
        Get the log-probability of a sample. For bounded priors the
        probability will not be properly normalised. **this method needs overwritten by child class**

        Parameters
        ==========
        samp: vector
            sample to evaluate the ln_prob at
        lnprob: vector
            of -inf passed in with the same shape as the number of samples
        outbounds: array_like
            boolean array showing which samples in lnprob vector are out of the given bounds

        Returns
        =======
        lnprob: vector
            array of lnprob values for each sample given
        """
        # child classes compute the log-probability here
        return lnprob

    def sample(self, size=1, **kwargs):
        """
        Draw, and set, a sample from the Dist, accompanying method _sample needs to overwritten

        Parameters
        ==========
        size: int
            number of samples to generate, defaults to 1
        """
        if size is None:
            size = 1
        samps = self._sample(size=size, **kwargs)
        for i, name in enumerate(self.names):
            if size == 1:
                self.current_sample[name] = samps[:, i].flatten()[0]
            else:
                self.current_sample[name] = samps[:, i].flatten()

    def _sample(self, size, **kwargs):
        """
        Draw, and set, a sample from the joint dist (**needs to be ovewritten by child class**)

        Parameters
        ==========
        size: int
            number of samples to generate, defaults to 1
        """
        samps = np.zeros((size, len(self)))
        # child classes draw their samples here
        return samps

    def rescale(self, value, **kwargs):
        """
        Rescale from a unit hypercube to JointPriorDist. Note that no
        bounds are applied in the rescale function. (child classes need to
        overwrite accompanying method _rescale().

        Parameters
        ==========
        value: array
            A 1d vector sample (one for each parameter) drawn from a uniform
            distribution between 0 and 1, or a 2d NxM array of samples where
            N is the number of samples and M is the number of parameters.
        kwargs: dict
            All keyword args that need to be passed to _rescale method, these keyword
            args are called in the JointPrior rescale methods for each parameter

        Returns
        =======
        array:
            The rescaled sample(s), with singleton dimensions squeezed out.
        """
        samp = self._as_2d_samples(value)
        samp = self._rescale(samp, **kwargs)
        return np.squeeze(samp)

    def _rescale(self, samp, **kwargs):
        """
        rescale a sample from a unit hypercube to the joint dist (**needs to be ovewritten by child class**)

        Parameters
        ==========
        samp: numpy array
            this is a vector sample drawn from a uniform distribution to be rescaled to the distribution
        """
        # child classes apply their transform here
        return samp
class MultivariateGaussianDist(BaseJointPriorDist):
    def __init__(
        self,
        names,
        nmodes=1,
        mus=None,
        sigmas=None,
        corrcoefs=None,
        covs=None,
        weights=None,
        bounds=None,
    ):
        """
        A class defining a multi-variate Gaussian, allowing multiple modes for
        a Gaussian mixture model.

        Note: if using a multivariate Gaussian prior, with bounds, this can
        lead to biases in the marginal likelihood estimate and posterior
        estimate for nested samplers routines that rely on sampling from a unit
        hypercube and having a prior transform, e.g., nestle, dynesty and
        MultiNest.

        Parameters
        ==========
        names: list
            A list of the parameter names in the multivariate Gaussian. The
            listed parameters must have the same order that they appear in
            the lists of means, standard deviations, and the correlation
            coefficient, or covariance, matrices.
        nmodes: int
            The number of modes for the mixture model. This defaults to 1,
            which will be checked against the shape of the other inputs.
        mus: array_like
            A list of lists of means of each mode in a multivariate Gaussian
            mixture model. A single list can be given for a single mode. If
            this is None then means at zero will be assumed.
        sigmas: array_like
            A list of lists of the standard deviations of each mode of the
            multivariate Gaussian. These are combined with ``corrcoefs`` to
            build the covariance matrix of any mode whose covariance is not
            given directly in ``covs``. If this is None unit variances will
            be assumed.
        corrcoefs: array
            A list of square matrices containing the correlation coefficients
            of the parameters for each mode. If this is None it will be assumed
            that the parameters are uncorrelated.
        covs: array
            A list of square matrices containing the covariance matrix of the
            multivariate Gaussian. When given, these take precedence over
            ``sigmas`` and ``corrcoefs``.
        weights: list
            A list of weights (relative probabilities) for each mode of the
            multivariate Gaussian. This will default to equal weights for each
            mode.
        bounds: list
            A list of bounds on each parameter. The defaults are for bounds at
            +/- infinity.
        """
        super(MultivariateGaussianDist, self).__init__(names=names, bounds=bounds)
        if any(
            bound[0] != -np.inf or bound[1] != np.inf
            for bound in self.bounds.values()
        ):
            # warn once, however many parameters are bounded
            logger.warning(
                "If using bounded ranges on the multivariate "
                "Gaussian this will lead to biased posteriors "
                "for nested sampling routines that require "
                "a prior transform."
            )
        self.distname = "mvg"
        self.mus = []
        self.covs = []
        self.corrcoefs = []
        self.sigmas = []
        self.logprodsigmas = []  # log of product of sigmas, needed for "standard" multivariate normal
        self.weights = []
        self.eigvalues = []
        self.eigvectors = []
        self.sqeigvalues = []  # square root of the eigenvalues
        self.mvn = []  # list of multivariate normal distributions

        # put values in lists if required
        if nmodes == 1:
            if mus is not None:
                if len(np.shape(mus)) == 1:
                    mus = [mus]
                elif len(np.shape(mus)) == 0:
                    raise ValueError("Must supply a list of means")
            if sigmas is not None:
                if len(np.shape(sigmas)) == 1:
                    sigmas = [sigmas]
                elif len(np.shape(sigmas)) == 0:
                    raise ValueError("Must supply a list of standard deviations")
            if covs is not None:
                if isinstance(covs, np.ndarray):
                    covs = [covs]
                elif isinstance(covs, list):
                    if len(np.shape(covs)) == 2:
                        covs = [np.array(covs)]
                    elif len(np.shape(covs)) != 3:
                        raise TypeError("List of covariances the wrong shape")
                else:
                    raise TypeError("Must pass a list of covariances")
            if corrcoefs is not None:
                if isinstance(corrcoefs, np.ndarray):
                    corrcoefs = [corrcoefs]
                elif isinstance(corrcoefs, list):
                    if len(np.shape(corrcoefs)) == 2:
                        corrcoefs = [np.array(corrcoefs)]
                    elif len(np.shape(corrcoefs)) != 3:
                        raise TypeError(
                            "List of correlation coefficients the wrong shape"
                        )
                else:
                    raise TypeError("Must pass a list of correlation coefficients")
            if weights is not None:
                if isinstance(weights, (int, float)):
                    weights = [weights]
                elif isinstance(weights, list):
                    if len(weights) != 1:
                        raise ValueError("Wrong number of weights given")

        for val in [mus, sigmas, covs, corrcoefs, weights]:
            if val is None:
                continue
            if not isinstance(val, list):
                raise TypeError("Value must be a list")
            if len(val) != nmodes:
                raise ValueError("Wrong number of modes given")

        # add the modes
        self.nmodes = 0
        for i in range(nmodes):
            mu = mus[i] if mus is not None else None
            sigma = sigmas[i] if sigmas is not None else None
            corrcoef = corrcoefs[i] if corrcoefs is not None else None
            cov = covs[i] if covs is not None else None
            weight = weights[i] if weights is not None else 1.0

            self.add_mode(mu, sigma, corrcoef, cov, weight)

    def add_mode(self, mus=None, sigmas=None, corrcoef=None, cov=None, weight=1.0):
        """
        Add a new mode to the Gaussian mixture.

        If ``cov`` is given it defines the mode's covariance, and ``sigmas`` and
        ``corrcoef`` are ignored. Otherwise the covariance is built from
        ``sigmas`` (default: unit standard deviations) and ``corrcoef``
        (default: the identity, i.e. uncorrelated parameters).

        Parameters
        ==========
        mus: array_like, optional
            Means of the mode, one per parameter. Defaults to zeros.
        sigmas: array_like, optional
            Standard deviations of the mode, one per parameter.
        corrcoef: array_like, optional
            Square, symmetric correlation coefficient matrix with unit diagonal.
        cov: array_like, optional
            Square, symmetric covariance matrix.
        weight: float, optional
            Relative weight of the mode (None is treated as 1.0).

        Raises
        ======
        TypeError
            If ``mus`` or ``sigmas`` cannot be converted to a list.
        ValueError
            If a matrix has the wrong shape, is not symmetric, is not positive
            definite, or the number of standard deviations is wrong.
        RuntimeError
            If the eigen-decomposition fails.
        """
        # add means
        if mus is not None:
            try:
                self.mus.append(list(mus))  # means
            except TypeError:
                raise TypeError("'mus' must be a list")
        else:
            self.mus.append(np.zeros(self.num_vars))

        if cov is not None:
            # the covariance matrix takes precedence over sigmas/corrcoef
            self.covs.append(np.asarray(cov))

            if len(self.covs[-1].shape) != 2:
                raise ValueError("Covariance matrix must be a 2d array")

            if (
                self.covs[-1].shape[0] != self.covs[-1].shape[1]
                or self.covs[-1].shape[0] != self.num_vars
            ):
                raise ValueError("Covariance shape is inconsistent")

            # check matrix is symmetric
            if not np.allclose(self.covs[-1], self.covs[-1].T):
                raise ValueError("Covariance matrix is not symmetric")

            self.sigmas.append(np.sqrt(np.diag(self.covs[-1])))  # standard deviations

            # convert covariance into a correlation coefficient matrix
            D = self.sigmas[-1] * np.identity(self.covs[-1].shape[0])
            Dinv = np.linalg.inv(D)
            self.corrcoefs.append(np.dot(np.dot(Dinv, self.covs[-1]), Dinv))
        else:
            if corrcoef is None:
                # no correlation matrix given: the parameters are uncorrelated
                corrcoef = np.eye(self.num_vars)
            self.corrcoefs.append(np.asarray(corrcoef))

            if len(self.corrcoefs[-1].shape) != 2:
                raise ValueError(
                    "Correlation coefficient matrix must be a 2d array."
                )

            if (
                self.corrcoefs[-1].shape[0] != self.corrcoefs[-1].shape[1]
                or self.corrcoefs[-1].shape[0] != self.num_vars
            ):
                raise ValueError(
                    "Correlation coefficient matrix shape is inconsistent"
                )

            # check matrix is symmetric
            if not np.allclose(self.corrcoefs[-1], self.corrcoefs[-1].T):
                raise ValueError("Correlation coefficient matrix is not symmetric")

            # check diagonal is all ones
            if not np.all(np.diag(self.corrcoefs[-1]) == 1.0):
                raise ValueError("Correlation coefficient matrix is not correct")

            if sigmas is None:
                # no standard deviations given: assume unit variances
                self.sigmas.append(np.ones(self.num_vars))
            else:
                try:
                    self.sigmas.append(list(sigmas))  # standard deviations
                except TypeError:
                    raise TypeError("'sigmas' must be a list")

                if len(self.sigmas[-1]) != self.num_vars:
                    raise ValueError(
                        "Number of standard deviations must be the "
                        "same as the number of parameters."
                    )

            # convert correlation coefficients to covariance matrix
            D = self.sigmas[-1] * np.identity(self.corrcoefs[-1].shape[0])
            self.covs.append(np.dot(D, np.dot(self.corrcoefs[-1], D)))

        # compute log of product of sigmas, needed for "standard" multivariate normal
        self.logprodsigmas.append(np.log(np.prod(self.sigmas[-1])))

        # get eigen values and vectors; eigh is the routine for symmetric
        # matrices and guarantees real eigenvalues and orthonormal
        # eigenvectors (eig does not for repeated eigenvalues), which the
        # rotation in _rescale relies on
        try:
            evals, evecs = np.linalg.eigh(self.corrcoefs[-1])
            self.eigvalues.append(evals)
            self.eigvectors.append(evecs)
        except Exception as e:
            raise RuntimeError(
                "Problem getting eigenvalues and vectors: {}".format(e)
            ) from e

        # check eigenvalues are positive
        if np.any(self.eigvalues[-1] <= 0.0):
            raise ValueError(
                "Correlation coefficient matrix is not positive definite"
            )
        self.sqeigvalues.append(np.sqrt(self.eigvalues[-1]))

        # set the weights
        if weight is None:
            self.weights.append(1.0)
        else:
            self.weights.append(weight)

        # set the cumulative relative weights
        self.cumweights = np.cumsum(self.weights) / np.sum(self.weights)

        # add the mode
        self.nmodes += 1

        # add "standard" multivariate normal distribution
        # - when the typical scales of the parameters are very different,
        #   multivariate_normal() may complain that the covariance matrix is singular
        # - instead pass zero means and correlation matrix instead of covariance matrix
        #   to get the equivalent of a standard normal distribution in higher dimensions
        # - this modifies the multivariate normal PDF as follows:
        #     multivariate_normal(mean=mus, cov=cov).logpdf(x)
        #     = multivariate_normal(mean=0, cov=corrcoefs).logpdf((x - mus)/sigmas) - logprodsigmas
        self.mvn.append(
            scipy.stats.multivariate_normal(mean=np.zeros(self.num_vars), cov=self.corrcoefs[-1])
        )

    def _draw_modes(self, size=None):
        """Randomly pick mode index(es) with probability proportional to the mode weights.

        Returns a single int when ``size`` is None, otherwise a list of ints.
        """
        draws = random.rng.uniform(0, 1, size)
        # first index whose cumulative weight exceeds the draw; clipped in case
        # the final cumulative weight rounds to just below one
        idx = np.minimum(
            np.searchsorted(self.cumweights, draws, side="right"), self.nmodes - 1
        )
        return int(idx) if size is None else idx.tolist()

    def _rescale(self, samp, **kwargs):
        """
        Map unit-hypercube samples onto the Gaussian of one mode.

        The mode is taken from ``kwargs["mode"]`` if given, otherwise mode 0 for
        a single-mode distribution, or a mode drawn at random by weight.
        """
        mode = kwargs.get("mode")
        if mode is None:
            mode = 0 if self.nmodes == 1 else self._draw_modes()

        # uniform -> standard normal via the inverse error function
        samp = erfinv(2.0 * samp - 1) * 2.0 ** 0.5

        # rotate and scale to the multivariate normal shape
        samp = self.mus[mode] + self.sigmas[mode] * np.einsum(
            "ij,kj->ik", samp * self.sqeigvalues[mode], self.eigvectors[mode]
        )
        return samp

    def _in_bounds(self, samp):
        """Return True if every value of the 1d sample lies within its parameter's bounds."""
        return not any(
            val < self.bounds[name][0] or val > self.bounds[name][1]
            for name, val in zip(self.names, samp)
        )

    def _sample(self, size, **kwargs):
        """
        Draw ``size`` samples, redrawing any that fall outside the bounds.

        Returns an array of shape (size, number of parameters).
        """
        mode = kwargs.get("mode")
        if mode is None:
            if self.nmodes == 1:
                mode = 0
            elif size == 1:
                mode = self._draw_modes()
            else:
                # pick one mode per sample
                mode = self._draw_modes(size)

        samps = np.zeros((size, len(self)))
        for i in range(size):
            this_mode = mode[i] if isinstance(mode, list) else mode
            # rejection sampling: redraw until the sample is within bounds
            while True:
                vals = random.rng.uniform(0, 1, len(self))
                samp = np.atleast_1d(self.rescale(vals, mode=this_mode))
                samps[i, :] = samp
                if self._in_bounds(samp):
                    break

        return samps

    def _ln_prob(self, samp, lnprob, outbounds):
        """
        Sum the mixture components' probabilities for every sample.

        ``samp`` is an (N, M) array; one log-probability is returned per row,
        with out-of-bounds samples set to -inf.
        """
        for mu, sigma, logprodsigma, mvn in zip(
            self.mus, self.sigmas, self.logprodsigmas, self.mvn
        ):
            # each mvn is a "standard" multivariate normal; see add_mode()
            z = (samp - np.asarray(mu)) / np.asarray(sigma)
            # scipy squeezes a single-row result to a scalar, hence atleast_1d
            lnprob = np.logaddexp(lnprob, np.atleast_1d(mvn.logpdf(z)) - logprodsigma)

        # set out-of-bounds values to -inf
        lnprob[outbounds] = -np.inf
        return lnprob

    def __eq__(self, other):
        """
        Compare two distributions attribute by attribute.

        Frozen scipy distributions in ``mvn`` are only type-checked (they are
        derived from the other stored attributes). Float arrays must have
        non-finite entries in the same places and agree closely on their
        finite entries.
        """
        if self.__class__ != other.__class__:
            return False
        if sorted(self.__dict__.keys()) != sorted(other.__dict__.keys()):
            return False
        frozen_type = scipy.stats._multivariate.multivariate_normal_frozen
        for key, this in self.__dict__.items():
            that = other.__dict__[key]
            if key == "mvn":
                if len(this) != len(that):
                    return False
                for thismvn, othermvn in zip(this, that):
                    if not isinstance(thismvn, frozen_type) or not isinstance(
                        othermvn, frozen_type
                    ):
                        return False
            elif isinstance(this, (np.ndarray, list)):
                thisarr = np.asarray(this)
                otherarr = np.asarray(that)
                if thisarr.dtype == float and otherarr.dtype == float:
                    fin1 = np.isfinite(thisarr)
                    fin2 = np.isfinite(otherarr)
                    if not np.array_equal(fin1, fin2):
                        return False
                    if not np.allclose(thisarr[fin1], otherarr[fin2], atol=1e-15):
                        return False
                elif not np.array_equal(thisarr, otherarr):
                    return False
            elif not this == that:
                return False
        return True

    @classmethod
    def from_repr(cls, string):
        """Generate the distribution from its __repr__"""
        return cls._from_repr(string)

    @classmethod
    def _from_repr(cls, string):
        """Parse ``Name(key=value, ...)`` argument text into a new instance."""
        subclass_args = infer_args_from_method(cls.__init__)

        string = string.replace(" ", "")
        kwargs = cls._split_repr(string)
        for key in kwargs:
            val = kwargs[key]
            if key not in subclass_args:
                raise AttributeError(
                    "Unknown argument {} for class {}".format(key, cls.__name__)
                )
            else:
                kwargs[key.strip()] = Prior._parse_argument_string(val)

        return cls(**kwargs)

    @classmethod
    def _split_repr(cls, string):
        """Split ``key=value`` pairs, keeping bracketed/braced values intact."""
        string = string.replace(",", ", ")
        # see https://stackoverflow.com/a/72146415/1862861
        args = re.findall(r"(\w+)=(\[.*?]|{.*?}|\S+)(?=\s*,\s*\w+=|\Z)", string)
        kwargs = dict()
        for key, arg in args:
            kwargs[key.strip()] = arg
        return kwargs
class MultivariateNormalDist(MultivariateGaussianDist):
    """
    Alias of :class:`~bilby.core.prior.MultivariateGaussianDist`.

    It adds nothing to the parent class; it only provides the alternative
    "Normal" spelling of the name.
    """
class JointPrior(Prior):
    def __init__(self, dist, name=None, latex_label=None, unit=None):
        """This defines the single parameter Prior object for parameters that belong to a JointPriorDist

        Parameters
        ==========
        dist: ChildClass of BaseJointPriorDist
            The shared JointPriorDistribution that this parameter belongs to
        name: str
            Name of this parameter. Must be contained in dist.names
        latex_label: str
            See superclass
        unit: str
            See superclass

        Raises
        ======
        TypeError
            If ``dist`` is not a BaseJointPriorDist (or subclass) instance.
        ValueError
            If ``name`` is not one of ``dist.names``.
        """
        # isinstance (rather than inspecting __bases__) also accepts indirect
        # subclasses such as MultivariateNormalDist
        if not isinstance(dist, BaseJointPriorDist):
            raise TypeError(
                "Must supply a JointPriorDist object instance to be shared by all joint params"
            )

        if name not in dist.names:
            raise ValueError(
                "'{}' is not a parameter in the JointPriorDist".format(name)
            )

        # set before the superclass initialiser, which presumably assigns
        # minimum/maximum through the setters below (they write to self.dist)
        self.dist = dist
        super(JointPrior, self).__init__(
            name=name,
            latex_label=latex_label,
            unit=unit,
            minimum=dist.bounds[name][0],
            maximum=dist.bounds[name][1],
        )

    @property
    def minimum(self):
        return self._minimum

    @minimum.setter
    def minimum(self, minimum):
        # keep the shared distribution's bounds in sync with this parameter
        self._minimum = minimum
        self.dist.bounds[self.name] = (minimum, self.dist.bounds[self.name][1])

    @property
    def maximum(self):
        return self._maximum

    @maximum.setter
    def maximum(self, maximum):
        # keep the shared distribution's bounds in sync with this parameter
        self._maximum = maximum
        self.dist.bounds[self.name] = (self.dist.bounds[self.name][0], maximum)

    def rescale(self, val, **kwargs):
        """
        Scale a unit hypercube sample to the prior.

        The value is stored on the shared distribution; only once every
        parameter of the distribution has supplied its value is the joint
        rescaling performed.

        Parameters
        ==========
        val: array_like
            value drawn from unit hypercube to be rescaled onto the prior
        kwargs: dict
            all kwargs passed to the dist.rescale method

        Returns
        =======
        array_like or list:
            The rescaled sample for all joint parameters once every parameter
            has been supplied, otherwise an empty list.
        """
        self.dist.rescale_parameters[self.name] = val

        if not self.dist.filled_rescale():
            return []  # wait for the remaining parameters

        try:
            values = np.array(list(self.dist.rescale_parameters.values())).T
            return self.dist.rescale(values, **kwargs)
        finally:
            # always clear the shared cache, even if rescaling failed, so a
            # failed call cannot leak stale values into the next one
            self.dist.reset_rescale()

    def sample(self, size=1, **kwargs):
        """
        Draw a sample from the prior.

        One joint draw is made for all parameters of the distribution; each
        parameter then returns its component of that draw.

        Parameters
        ==========
        size: int, float (defaults to 1)
            number of samples to draw
        kwargs: dict
            kwargs passed to the dist.sample method

        Returns
        =======
        float:
            A sample from the prior parameter.
        """
        if self.name in self.dist.sampled_parameters:
            logger.warning(
                "You have already drawn a sample from parameter "
                "'{}'. The same sample will be "
                "returned".format(self.name)
            )

        if len(self.dist.current_sample) == 0:
            # generate a sample
            self.dist.sample(size=size, **kwargs)

        sample = self.dist.current_sample[self.name]

        if self.name not in self.dist.sampled_parameters:
            self.dist.sampled_parameters.append(self.name)

        if len(self.dist.sampled_parameters) == len(self.dist):
            # every parameter has taken its component: reset samples
            self.dist.reset_sampled()
        self.least_recently_sampled = sample
        return sample

    def ln_prob(self, val):
        """
        Return the natural logarithm of the prior probability. Note that this
        will not be correctly normalised if there are bounds on the
        distribution.

        The joint log-probability is only evaluated once every parameter of
        the distribution has supplied its value; until then zeros are returned.

        Parameters
        ==========
        val: array_like
            value to evaluate the prior log-prob at

        Returns
        =======
        float:
            the logp value for the prior at given sample

        Raises
        ======
        ValueError
            If the parameters were given different numbers of values.
        TypeError
            If ``val`` is neither a number nor sized.
        """
        self.dist.requested_parameters[self.name] = val

        if self.dist.filled_request():
            # all required parameters have been set
            values = list(self.dist.requested_parameters.values())
            try:
                # check for the same number of values for each parameter
                for this, nxt in zip(values[:-1], values[1:]):
                    this_seq = isinstance(this, (list, np.ndarray))
                    next_seq = isinstance(nxt, (list, np.ndarray))
                    if this_seq or next_seq:
                        if not (this_seq and next_seq) or len(this) != len(nxt):
                            raise ValueError(
                                "Each parameter must have the same "
                                "number of requested values."
                            )

                return self.dist.ln_prob(np.asarray(values).T)
            finally:
                # reset the requested parameters, even on error, so the shared
                # distribution is not left holding stale values
                self.dist.reset_request()

        # if not all parameters have been requested yet, just return 0
        if isinstance(val, (float, int, np.number)):
            return 0.0
        try:
            # check value has a length
            length = len(val)
        except Exception as e:
            raise TypeError("Invalid type for ln_prob: {}".format(e))

        if length == 1:
            return 0.0
        return np.zeros_like(val)

    def prob(self, val):
        """Return the prior probability of val

        Parameters
        ==========
        val: array_like
            value to evaluate the prior prob at

        Returns
        =======
        float:
            the p value for the prior at given sample
        """
        return np.exp(self.ln_prob(val))
class MultivariateGaussian(JointPrior):
    def __init__(self, dist, name=None, latex_label=None, unit=None):
        """
        Single-parameter prior for one parameter of a MultivariateGaussianDist.

        Parameters
        ==========
        dist: MultivariateGaussianDist
            The shared multivariate Gaussian distribution.
        name: str
            Name of this parameter; must be one of ``dist.names``.
        latex_label: str
            See superclass
        unit: str
            See superclass

        Raises
        ======
        JointPriorDistError
            If ``dist`` is not a MultivariateGaussianDist instance.
        """
        if isinstance(dist, MultivariateGaussianDist):
            super().__init__(dist=dist, name=name, latex_label=latex_label, unit=unit)
            return
        raise JointPriorDistError(
            "dist object must be instance of MultivariateGaussianDist"
        )
class MultivariateNormal(MultivariateGaussian):
    """
    Alias of :class:`bilby.core.prior.MultivariateGaussian`.

    It adds no behaviour of its own; it only provides the alternative
    "Normal" spelling of the prior's name.
    """
class JointPriorDistError(PriorException):
    """
    Error raised when a JointPrior is given an unsuitable JointPriorDist.

    For example, MultivariateGaussian raises it when its distribution is not
    a MultivariateGaussianDist.
    """