Coverage for bilby/core/prior/interpolated.py: 87%

86 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import numpy as np 

2from scipy.interpolate import interp1d 

3 

4from .base import Prior 

5from ..utils import logger 

6 

7 

class Interped(Prior):

    def __init__(self, xx, yy, minimum=np.nan, maximum=np.nan, name=None,
                 latex_label=None, unit=None, boundary=None):
        """Creates an interpolated prior function from arrays of xx and yy=p(xx)

        Parameters
        ==========
        xx: array_like
            x values for the to be interpolated prior function
        yy: array_like
            p(xx) values for the to be interpolated prior function
        minimum: float
            See superclass. ``np.nan`` (default) or ``None`` means "use min(xx)";
            otherwise the larger of ``minimum`` and ``min(xx)`` is used.
        maximum: float
            See superclass. ``np.nan`` (default) or ``None`` means "use max(xx)";
            otherwise the smaller of ``maximum`` and ``max(xx)`` is used.
        name: str
            See superclass
        latex_label: str
            See superclass
        unit: str
            See superclass
        boundary: str
            See superclass

        Attributes
        ==========
        probability_density: scipy.interpolate.interp1d
            Interpolated prior probability distribution
        cumulative_distribution: scipy.interpolate.interp1d
            Interpolated cumulative prior probability distribution
        inverse_cumulative_distribution: scipy.interpolate.interp1d
            Inverted cumulative prior probability distribution
        YY: array_like
            Cumulative prior probability distribution

        """
        self.xx = xx
        # Hard limits: the bounds can never be moved outside the supplied grid.
        self.min_limit = min(xx)
        self.max_limit = max(xx)
        self._yy = yy
        self.YY = None
        self.probability_density = None
        self.cumulative_distribution = None
        self.inverse_cumulative_distribution = None
        # Interpolant over the supplied data; zero outside its x-range.
        self.__all_interpolated = interp1d(x=xx, y=yy, bounds_error=False, fill_value=0)
        # Accept ``None`` as "no bound". Passing ``None`` into np.nanmax/np.nanmin
        # builds an object array and raises TypeError on the None/float comparison.
        if minimum is None:
            minimum = np.nan
        if maximum is None:
            maximum = np.nan
        # nanmax/nanmin ignore a NaN bound, so the effective range is the requested
        # range clipped to the span of ``xx``.
        minimum = float(np.nanmax(np.array((self.min_limit, minimum))))
        maximum = float(np.nanmin(np.array((self.max_limit, maximum))))
        super(Interped, self).__init__(name=name, latex_label=latex_label, unit=unit,
                                       minimum=minimum, maximum=maximum, boundary=boundary)
        self._update_instance()

    def __eq__(self, other):
        """Two Interped priors are equal when they are the same class and share
        identical (resampled) ``xx`` and ``yy`` arrays."""
        if self.__class__ != other.__class__:
            return False
        if np.array_equal(self.xx, other.xx) and np.array_equal(self.yy, other.yy):
            return True
        return False

    def prob(self, val):
        """Return the prior probability of val.

        Values outside [minimum, maximum] have probability 0.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[float, array_like]: Prior probability of val
        """
        return self.probability_density(val)

    def cdf(self, val):
        """Return the cumulative distribution function evaluated at val.

        Values below the grid map to 0 and values above it map to 1.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        Union[float, array_like]: CDF of val
        """
        return self.cumulative_distribution(val)

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the prior.

        This maps to the inverse CDF. This is done using interpolation.
        ``val`` must lie in [0, 1]; the underlying interpolant is built with
        ``bounds_error=True`` and raises ValueError otherwise.
        """
        rescaled = self.inverse_cumulative_distribution(val)
        # interp1d returns a 0-d array for scalar input; hand back a plain float.
        if rescaled.shape == ():
            rescaled = float(rescaled)
        return rescaled

    @property
    def minimum(self):
        """Return minimum of the prior distribution.

        Updates the prior distribution if minimum is set to a different value.

        Raises ValueError if the value is set below the instantiated x-array minimum.

        Returns
        =======
        float: Minimum of the prior distribution

        """
        return self._minimum

    @minimum.setter
    def minimum(self, minimum):
        if minimum < self.min_limit:
            raise ValueError('Minimum cannot be set below {}.'.format(round(self.min_limit, 2)))
        self._minimum = minimum
        # Only rebuild once the other bound exists; the first bound assignment
        # (presumably from the base-class constructor) happens before it is set.
        if '_maximum' in self.__dict__ and self._maximum < np.inf:
            self._update_instance()

    @property
    def maximum(self):
        """Return maximum of the prior distribution.

        Updates the prior distribution if maximum is set to a different value.

        Raises ValueError if the value is set above the instantiated x-array maximum.

        Returns
        =======
        float: Maximum of the prior distribution

        """
        return self._maximum

    @maximum.setter
    def maximum(self, maximum):
        if maximum > self.max_limit:
            raise ValueError('Maximum cannot be set above {}.'.format(round(self.max_limit, 2)))
        self._maximum = maximum
        # Only rebuild once the other bound exists (see the minimum setter).
        if '_minimum' in self.__dict__ and self._minimum < np.inf:
            self._update_instance()

    @property
    def yy(self):
        """Return p(xx) values of the interpolated prior function.

        Updates the prior distribution if it is changed. New values must be
        given on the current ``xx`` grid (same length as ``self.xx``).

        Returns
        =======
        array_like: p(xx) values

        """
        return self._yy

    @yy.setter
    def yy(self, yy):
        self._yy = yy
        self.__all_interpolated = interp1d(x=self.xx, y=self._yy, bounds_error=False, fill_value=0)
        self._update_instance()

    def _update_instance(self):
        """Resample onto an evenly spaced grid over [minimum, maximum] and rebuild.

        The number of grid points is kept equal to the current ``len(self.xx)``.
        """
        self.xx = np.linspace(self.minimum, self.maximum, len(self.xx))
        self._yy = self.__all_interpolated(self.xx)
        self._initialize_attributes()

    def _initialize_attributes(self):
        """Normalise ``_yy`` and build the pdf, cdf and inverse-cdf interpolants."""
        # scipy.integrate.trapezoid is used instead of np.trapz, which is
        # deprecated as of NumPy 2.0. It ships in the same SciPy releases as
        # cumulative_trapezoid, so no new requirement is introduced.
        from scipy.integrate import cumulative_trapezoid, trapezoid
        norm = trapezoid(self._yy, self.xx)
        if norm != 1:
            logger.debug('Supplied PDF for {} is not normalised, normalising.'.format(self.name))
        self._yy /= norm
        self.YY = cumulative_trapezoid(self._yy, self.xx, initial=0)
        # Need last element of cumulative distribution to be exactly one.
        self.YY[-1] = 1
        self.probability_density = interp1d(x=self.xx, y=self._yy, bounds_error=False, fill_value=0)
        self.cumulative_distribution = interp1d(x=self.xx, y=self.YY, bounds_error=False, fill_value=(0, 1))
        self.inverse_cumulative_distribution = interp1d(x=self.YY, y=self.xx, bounds_error=True)

175 

176 

class FromFile(Interped):

    def __init__(self, file_name, minimum=None, maximum=None, name=None,
                 latex_label=None, unit=None, boundary=None):
        """Creates an interpolated prior function from arrays of xx and yy=p(xx) extracted from a file

        Parameters
        ==========
        file_name: str
            Name of the file containing the xx and yy arrays, read with
            ``np.genfromtxt`` and expected to hold exactly two columns: x and p(x)
        minimum: float
            See superclass. ``None`` (default) means "use the smallest x in the file".
        maximum: float
            See superclass. ``None`` (default) means "use the largest x in the file".
        name: str
            See superclass
        latex_label: str
            See superclass
        unit: str
            See superclass
        boundary: str
            See superclass

        """
        self.file_name = file_name
        try:
            xx, yy = np.genfromtxt(self.file_name).T
        except IOError:
            # Best-effort: warn and return without initialising the prior.
            # NOTE(review): the instance is left without Interped/Prior state and
            # is unusable; kept as-is for backward compatibility.
            logger.warning("Can't load {}.".format(self.file_name))
            logger.warning("Format should be:")
            logger.warning(r"x\tp(x)")
            return
        # Interped uses NaN for "no bound". Forwarding None would make np.nanmax /
        # np.nanmin build an object array and raise TypeError.
        if minimum is None:
            minimum = np.nan
        if maximum is None:
            maximum = np.nan
        super(FromFile, self).__init__(xx=xx, yy=yy, minimum=minimum,
                                       maximum=maximum, name=name, latex_label=latex_label,
                                       unit=unit, boundary=boundary)