Coverage for bilby/core/prior/interpolated.py: 87%
86 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import numpy as np
2from scipy.interpolate import interp1d
4from .base import Prior
5from ..utils import logger
class Interped(Prior):

    def __init__(self, xx, yy, minimum=np.nan, maximum=np.nan, name=None,
                 latex_label=None, unit=None, boundary=None):
        """Creates an interpolated prior function from arrays of xx and yy=p(xx)

        Parameters
        ==========
        xx: array_like
            x values for the to be interpolated prior function
        yy: array_like
            p(xx) values for the to be interpolated prior function
        minimum: float
            See superclass. Clipped so it is never below min(xx); ``nan``
            (the default) or ``None`` means "use min(xx)".
        maximum: float
            See superclass. Clipped so it is never above max(xx); ``nan``
            (the default) or ``None`` means "use max(xx)".
        name: str
            See superclass
        latex_label: str
            See superclass
        unit: str
            See superclass
        boundary: str
            See superclass

        Attributes
        ==========
        probability_density: scipy.interpolate.interp1d
            Interpolated prior probability distribution
        cumulative_distribution: scipy.interpolate.interp1d
            Interpolated cumulative prior probability distribution
        inverse_cumulative_distribution: scipy.interpolate.interp1d
            Inverted cumulative prior probability distribution
        YY: array_like
            Cumulative prior probability distribution

        """
        self.xx = xx
        # Hard limits of the supplied grid; the minimum/maximum setters refuse
        # to move the bounds outside this range.
        self.min_limit = min(xx)
        self.max_limit = max(xx)
        self._yy = yy
        self.YY = None
        self.probability_density = None
        self.cumulative_distribution = None
        self.inverse_cumulative_distribution = None
        # Interpolant over the full, originally supplied grid (zero outside it).
        # _update_instance resamples from this whenever the bounds change.
        self.__all_interpolated = interp1d(x=xx, y=yy, bounds_error=False, fill_value=0)
        # A ``None`` bound would turn the arrays below into object arrays and
        # np.nanmax/np.nanmin would then raise TypeError comparing None with a
        # float, so treat it like the ``nan`` default ("not given").
        if minimum is None:
            minimum = np.nan
        if maximum is None:
            maximum = np.nan
        # nanmax/nanmin ignore a nan bound and clip a given bound to the grid.
        minimum = float(np.nanmax(np.array((self.min_limit, minimum))))
        maximum = float(np.nanmin(np.array((self.max_limit, maximum))))
        super(Interped, self).__init__(name=name, latex_label=latex_label, unit=unit,
                                       minimum=minimum, maximum=maximum, boundary=boundary)
        self._update_instance()

    def __eq__(self, other):
        """Two priors are equal when they have the same class and identical xx and yy arrays."""
        if self.__class__ != other.__class__:
            return False
        return bool(np.array_equal(self.xx, other.xx) and np.array_equal(self.yy, other.yy))

    def prob(self, val):
        """Return the prior probability of val.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[float, array_like]: Prior probability of val (zero outside
        [minimum, maximum])
        """
        return self.probability_density(val)

    def cdf(self, val):
        """Return the cumulative probability at val (0 below minimum, 1 above maximum)."""
        return self.cumulative_distribution(val)

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the prior.

        This maps to the inverse CDF. This is done using interpolation.
        Values outside [0, 1] raise a ValueError from the interpolant
        (it is built with ``bounds_error=True``).
        """
        rescaled = self.inverse_cumulative_distribution(val)
        # Scalar input gives a 0-d array; return a plain float in that case.
        if rescaled.shape == ():
            rescaled = float(rescaled)
        return rescaled

    @property
    def minimum(self):
        """Return minimum of the prior distribution.

        Updates the prior distribution if minimum is set to a different value.

        Raises a ValueError if the value is set below the instantiated
        x-array minimum.

        Returns
        =======
        float: Minimum of the prior distribution

        """
        return self._minimum

    @minimum.setter
    def minimum(self, minimum):
        if minimum < self.min_limit:
            raise ValueError('Minimum cannot be set below {}.'.format(round(self.min_limit, 2)))
        self._minimum = minimum
        # Only resample once the other bound exists; presumably the superclass
        # constructor assigns the two bounds one after the other.
        if '_maximum' in self.__dict__ and self._maximum < np.inf:
            self._update_instance()

    @property
    def maximum(self):
        """Return maximum of the prior distribution.

        Updates the prior distribution if maximum is set to a different value.

        Raises a ValueError if the value is set above the instantiated
        x-array maximum.

        Returns
        =======
        float: Maximum of the prior distribution

        """
        return self._maximum

    @maximum.setter
    def maximum(self, maximum):
        if maximum > self.max_limit:
            raise ValueError('Maximum cannot be set above {}.'.format(round(self.max_limit, 2)))
        self._maximum = maximum
        # Only resample once the other bound exists (see minimum setter).
        if '_minimum' in self.__dict__ and self._minimum < np.inf:
            self._update_instance()

    @property
    def yy(self):
        """Return p(xx) values of the interpolated prior function.

        Updates the prior distribution if it is changed

        Returns
        =======
        array_like: p(xx) values

        """
        return self._yy

    @yy.setter
    def yy(self, yy):
        self._yy = yy
        # Rebuild the master interpolant on the current grid, then resample.
        self.__all_interpolated = interp1d(x=self.xx, y=self._yy, bounds_error=False, fill_value=0)
        self._update_instance()

    def _update_instance(self):
        """Resample the stored PDF onto an even grid over [minimum, maximum] and rebuild interpolants."""
        self.xx = np.linspace(self.minimum, self.maximum, len(self.xx))
        self._yy = self.__all_interpolated(self.xx)
        self._initialize_attributes()

    def _initialize_attributes(self):
        """Normalise ``_yy`` over ``xx`` and build the PDF, CDF and inverse-CDF interpolants."""
        # np.trapz is deprecated as of NumPy 2.0; scipy.integrate.trapezoid is
        # the same trapezoid rule and is available with cumulative_trapezoid.
        from scipy.integrate import cumulative_trapezoid, trapezoid
        normalisation = trapezoid(self._yy, self.xx)
        if normalisation != 1:
            logger.debug('Supplied PDF for {} is not normalised, normalising.'.format(self.name))
            self._yy = self._yy / normalisation
        self.YY = cumulative_trapezoid(self._yy, self.xx, initial=0)
        # Need last element of cumulative distribution to be exactly one.
        self.YY[-1] = 1
        self.probability_density = interp1d(x=self.xx, y=self._yy, bounds_error=False, fill_value=0)
        self.cumulative_distribution = interp1d(x=self.xx, y=self.YY, bounds_error=False, fill_value=(0, 1))
        self.inverse_cumulative_distribution = interp1d(x=self.YY, y=self.xx, bounds_error=True)
class FromFile(Interped):

    def __init__(self, file_name, minimum=None, maximum=None, name=None,
                 latex_label=None, unit=None, boundary=None):
        """Creates an interpolated prior function from arrays of xx and yy=p(xx) extracted from a file

        If the file cannot be read, warnings describing the expected format
        are logged and the instance is left uninitialised (the superclass
        constructor is not called).

        Parameters
        ==========
        file_name: str
            Name of the file containing the xx and yy arrays (two columns,
            read with ``np.genfromtxt``)
        minimum: float
            See superclass. ``None`` means "use the smallest x in the file".
        maximum: float
            See superclass. ``None`` means "use the largest x in the file".
        name: str
            See superclass
        latex_label: str
            See superclass
        unit: str
            See superclass
        boundary: str
            See superclass

        """
        self.file_name = file_name
        # Only the file read is guarded; errors while building the prior
        # itself should propagate rather than be reported as a load failure.
        try:
            xx, yy = np.genfromtxt(self.file_name).T
        except IOError:
            logger.warning("Can't load {}.".format(self.file_name))
            logger.warning("Format should be:")
            logger.warning(r"x\tp(x)")
            return
        # Interped's "no bound given" value is nan; passing None through would
        # make its np.nanmax/np.nanmin clipping fail on an object array.
        if minimum is None:
            minimum = np.nan
        if maximum is None:
            maximum = np.nan
        super(FromFile, self).__init__(xx=xx, yy=yy, minimum=minimum,
                                       maximum=maximum, name=name, latex_label=latex_label,
                                       unit=unit, boundary=boundary)