Coverage for bilby/core/prior/slabspike.py: 90%
83 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1from numbers import Number
2import numpy as np
4from .base import Prior
5from ..utils import logger
class SlabSpikePrior(Prior):

    def __init__(self, slab, spike_location=None, spike_height=0):
        """'Slab-and-spike' prior, see e.g. https://arxiv.org/abs/1812.07259
        This prior is composed of a `slab`, i.e. any common prior distribution,
        and a Dirac spike at a fixed location. This can effectively be used
        to emulate sampling in the number of dimensions (similar to reversible-
        jump MCMC).

        `SymmetricLogUniform` and `FermiDirac` are currently not supported.

        Parameters
        ==========
        slab: Prior
            Any instance of a bilby prior class. All general prior attributes
            from the slab are copied into the SlabSpikePrior.
            Note that this hasn't been tested for conditional priors.
        spike_location: float, optional
            Location of the Dirac spike. Must be between minimum and maximum
            of the slab. Defaults to the minimum of the slab
        spike_height: float, optional
            Relative weight of the spike compared to the slab. Must be
            between 0 and 1. Defaults to 0, i.e. the prior is just the slab.

        Raises
        ======
        ValueError
            If `spike_location` is outside [minimum, maximum] of the slab or
            `spike_height` is outside [0, 1].
        """
        self.slab = slab
        super().__init__(name=self.slab.name, latex_label=self.slab.latex_label, unit=self.slab.unit,
                         minimum=self.slab.minimum, maximum=self.slab.maximum,
                         check_range_nonzero=self.slab.check_range_nonzero, boundary=self.slab.boundary)
        # The location setter validates against self.minimum / self.maximum, so
        # it must run after the base-class initialisation above. Once the height
        # is also set, the setters compute `inverse_cdf_below_spike`.
        self.spike_location = spike_location
        self.spike_height = spike_height

    @property
    def spike_location(self):
        """Location of the Dirac spike (within [minimum, maximum])."""
        return self._spike_loc

    @spike_location.setter
    def spike_location(self, spike_loc):
        if spike_loc is None:
            spike_loc = self.minimum
        if not self.minimum <= spike_loc <= self.maximum:
            raise ValueError("Spike location {} not within prior domain ".format(spike_loc))
        self._spike_loc = spike_loc
        # Keep the cached rescale boundary consistent with the new location.
        self._update_inverse_cdf_below_spike()

    @property
    def spike_height(self):
        """Relative weight of the spike, in [0, 1]."""
        return self._spike_height

    @spike_height.setter
    def spike_height(self, spike_height):
        if not 0 <= spike_height <= 1:
            raise ValueError("Spike height must be between 0 and 1, but is {}".format(spike_height))
        self._spike_height = spike_height
        # Keep the cached rescale boundary consistent with the new height.
        self._update_inverse_cdf_below_spike()

    @property
    def slab_fraction(self):
        """ Relative prior weight of the slab. """
        return 1 - self.spike_height

    def _find_inverse_cdf_fraction_before_spike(self):
        """Return the unit-interval mass assigned to the slab below the spike.

        This is the slab CDF at the spike location, scaled by `slab_fraction`.
        """
        return float(self.slab.cdf(self.spike_location)) * self.slab_fraction

    def _update_inverse_cdf_below_spike(self):
        """Recompute the cached `inverse_cdf_below_spike` attribute.

        `rescale` uses this value as the lower edge of the unit-interval segment
        that is mapped onto the spike, so it must be refreshed whenever the
        spike location or height changes. Until both have been set (part-way
        through ``__init__``) there is nothing to compute yet.
        """
        if not (hasattr(self, "_spike_loc") and hasattr(self, "_spike_height")):
            return
        try:
            self.inverse_cdf_below_spike = self._find_inverse_cdf_fraction_before_spike()
        except Exception as e:
            # Drop any value cached for earlier spike settings so that `rescale`
            # fails loudly instead of silently using an out-of-date boundary.
            vars(self).pop("inverse_cdf_below_spike", None)
            logger.warning("Disregard the following warning when running tests:\n {}".format(e))

    @staticmethod
    def _maybe_unwrap_number(res, original_is_number):
        """Return ``res[0]`` if the caller passed a plain number, else ``res``."""
        if not original_is_number:
            return res
        try:
            return res[0]
        except (KeyError, TypeError):
            logger.warning("Based on inputs, a number should be output "
                           "but this could not be accessed from what was computed")
            return res

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the prior.

        The unit interval is split into three consecutive segments:
        ``[0, c)`` maps onto the slab below the spike, ``[c, c + spike_height]``
        maps onto the spike itself, and the remainder maps onto the slab above
        the spike, where ``c = inverse_cdf_below_spike``.

        Parameters
        ==========
        val: Union[float, int, array_like]
            A random number between 0 and 1. Arrays of any shape are handled
            elementwise.

        Returns
        =======
        Union[float, array_like]: Associated prior value with input value. A
        single value is returned if `val` is a number, otherwise an array with
        the shape of ``np.atleast_1d(val)``.
        """
        original_is_number = isinstance(val, Number)
        val = np.atleast_1d(val)

        lower_edge = self.inverse_cdf_below_spike
        upper_edge = lower_edge + self.spike_height

        # Boolean masks work for inputs of any dimensionality, unlike the
        # first-axis indices returned by np.where(...)[0].
        below_spike = val < lower_edge
        on_spike = (lower_edge <= val) & (val <= upper_edge)
        above_spike = val > upper_edge

        res = np.zeros(val.shape)
        res[below_spike] = self._contracted_rescale(val[below_spike])
        res[on_spike] = self.spike_location
        # Remove the spike's share of the unit interval before mapping onto the slab.
        res[above_spike] = self._contracted_rescale(val[above_spike] - self.spike_height)
        return self._maybe_unwrap_number(res, original_is_number)

    def _contracted_rescale(self, val):
        """
        Contracted version of the rescale function that implements the `rescale` function
        on the pure slab part of the prior.

        Parameters
        ==========
        val: Union[float, int, array_like]
            A random number between 0 and self.slab_fraction

        Returns
        =======
        array_like: Associated prior value with input value.
        """
        return self.slab.rescale(val / self.slab_fraction)

    def prob(self, val):
        """Return the prior probability of val.
        The slab density is scaled by `slab_fraction`; np.inf is returned at
        the spike location.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[float, array_like]: Prior probability of val. A single value is
        returned if `val` is a number.
        """
        original_is_number = isinstance(val, Number)
        res = np.atleast_1d(self.slab.prob(val) * self.slab_fraction)
        # np.asarray makes the comparison elementwise for list inputs too.
        res[np.atleast_1d(np.asarray(val) == self.spike_location)] = np.inf
        return self._maybe_unwrap_number(res, original_is_number)

    def ln_prob(self, val):
        """Return the log prior probability of val.
        The slab log-density is shifted by ``log(slab_fraction)``; np.inf is
        returned at the spike location.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[float, array_like]: Log prior probability of val. A single value
        is returned if `val` is a number.
        """
        original_is_number = isinstance(val, Number)
        res = np.atleast_1d(self.slab.ln_prob(val) + np.log(self.slab_fraction))
        # np.asarray makes the comparison elementwise for list inputs too.
        res[np.atleast_1d(np.asarray(val) == self.spike_location)] = np.inf
        return self._maybe_unwrap_number(res, original_is_number)

    def cdf(self, val):
        """ Return the CDF of the prior.
        This calls to the slab CDF and adds a discrete step
        at the spike location.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        array_like: CDF value of val (always an array, even for a single
        number).
        """
        res = np.atleast_1d(self.slab.cdf(val) * self.slab_fraction)
        # NOTE: the comparison is strict, so at exactly the spike location the
        # spike mass is not yet included (the left limit of the CDF). A
        # right-continuous CDF would use `>=`; confirm which convention callers
        # rely on before changing it.
        res[np.atleast_1d(np.asarray(val) > self.spike_location)] += self.spike_height
        return res