Coverage for bilby/core/prior/slabspike.py: 90%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1from numbers import Number 

2import numpy as np 

3 

4from .base import Prior 

5from ..utils import logger 

6 

7 

class SlabSpikePrior(Prior):
    """Prior made of a continuous `slab` prior plus a Dirac spike at one point."""

    def __init__(self, slab, spike_location=None, spike_height=0):
        """'Slab-and-spike' prior, see e.g. https://arxiv.org/abs/1812.07259
        This prior is composed of a `slab`, i.e. any common prior distribution,
        and a Dirac spike at a fixed location. This can effectively be used
        to emulate sampling in the number of dimensions (similar to reversible-
        jump MCMC).

        `SymmetricLogUniform` and `FermiDirac` are currently not supported.

        Parameters
        ==========
        slab: Prior
            Any instance of a bilby prior class. All general prior attributes
            from the slab are copied into the SlabSpikePrior.
            Note that this hasn't been tested for conditional priors.
        spike_location: float, optional
            Location of the Dirac spike. Must be between minimum and maximum
            of the slab. Defaults to the minimum of the slab
        spike_height: float, optional
            Relative weight of the spike compared to the slab. Must be
            between 0 and 1. Defaults to 0, i.e. the prior is just the slab.

        Raises
        ======
        ValueError
            If `spike_location` lies outside `[minimum, maximum]` or
            `spike_height` lies outside `[0, 1]`.

        """
        self.slab = slab
        super().__init__(name=self.slab.name, latex_label=self.slab.latex_label, unit=self.slab.unit,
                         minimum=self.slab.minimum, maximum=self.slab.maximum,
                         check_range_nonzero=self.slab.check_range_nonzero, boundary=self.slab.boundary)
        self.spike_location = spike_location
        self.spike_height = spike_height
        # Best effort: a slab whose cdf cannot be evaluated only triggers a
        # warning here; `inverse_cdf_below_spike` is then left undefined.
        try:
            self.inverse_cdf_below_spike = self._find_inverse_cdf_fraction_before_spike()
        except Exception as e:
            logger.warning("Disregard the following warning when running tests:\n {}".format(e))

    @property
    def spike_location(self):
        """ Location of the Dirac spike. """
        return self._spike_loc

    @spike_location.setter
    def spike_location(self, spike_loc):
        # `None` selects the lower edge of the prior domain.
        if spike_loc is None:
            spike_loc = self.minimum
        if not self.minimum <= spike_loc <= self.maximum:
            raise ValueError("Spike location {} not within prior domain ".format(spike_loc))
        self._spike_loc = spike_loc
        self._refresh_inverse_cdf_below_spike()

    @property
    def spike_height(self):
        """ Relative prior weight of the spike. """
        return self._spike_height

    @spike_height.setter
    def spike_height(self, spike_height):
        if not 0 <= spike_height <= 1:
            raise ValueError("Spike height must be between 0 and 1, but is {}".format(spike_height))
        self._spike_height = spike_height
        self._refresh_inverse_cdf_below_spike()

    @property
    def slab_fraction(self):
        """ Relative prior weight of the slab. """
        return 1 - self.spike_height

    def _find_inverse_cdf_fraction_before_spike(self):
        """ Return the slab's CDF at the spike location scaled by `slab_fraction`. """
        return float(self.slab.cdf(self.spike_location)) * self.slab_fraction

    def _refresh_inverse_cdf_below_spike(self):
        """ Recompute the cached `inverse_cdf_below_spike` after a spike parameter changed.

        Nothing happens while the cache does not exist yet, i.e. during
        `__init__` or when the initial computation failed, so construction
        behaves exactly as before and no additional warnings are produced.
        """
        if hasattr(self, "inverse_cdf_below_spike"):
            self.inverse_cdf_below_spike = self._find_inverse_cdf_fraction_before_spike()

    @staticmethod
    def _match_input_type(res, original_is_number):
        """ Return the single element of `res` for scalar input, else `res` itself. """
        return res[0] if original_is_number else res

    def rescale(self, val):
        """
        'Rescale' a sample from the unit line element to the prior.

        Parameters
        ==========
        val: Union[float, int, array_like]
            A random number between 0 and 1

        Returns
        =======
        array_like: Associated prior value with input value.
        """
        original_is_number = isinstance(val, Number)
        val = np.atleast_1d(val)

        # The unit interval is cut into three pieces: slab mass below the
        # spike, the spike itself (width `spike_height`), slab mass above it.
        spike_start = self.inverse_cdf_below_spike
        spike_end = spike_start + self.spike_height
        below_spike = val < spike_start
        on_spike = np.logical_and(spike_start <= val, val <= spike_end)
        above_spike = val > spike_end

        # Boolean masks and `val.shape` keep this valid for any input shape.
        res = np.zeros(val.shape)
        res[below_spike] = self._contracted_rescale(val[below_spike])
        res[on_spike] = self.spike_location
        # Remove the spike's width before mapping back onto the slab.
        res[above_spike] = self._contracted_rescale(val[above_spike] - self.spike_height)
        return self._match_input_type(res, original_is_number)

    def _contracted_rescale(self, val):
        """
        Contracted version of the rescale function that implements the `rescale` function
        on the pure slab part of the prior.

        Parameters
        ==========
        val: Union[float, int, array_like]
            A random number between 0 and self.slab_fraction

        Returns
        =======
        array_like: Associated prior value with input value.
        """
        return self.slab.rescale(val / self.slab_fraction)

    def prob(self, val):
        """Return the prior probability of val.
        Returns np.inf for the spike location

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        array_like: Prior probability of val
        """
        original_is_number = isinstance(val, Number)
        res = np.atleast_1d(self.slab.prob(val) * self.slab_fraction)
        # Compare on an array view of `val`: single-argument `np.where` on a
        # scalar condition is deprecated in NumPy, and a plain list compared
        # with `==` would never match the spike at all.
        res[np.atleast_1d(val) == self.spike_location] = np.inf
        return self._match_input_type(res, original_is_number)

    def ln_prob(self, val):
        """Return the log prior probability of val.
        Returns np.inf for the spike location

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        array_like: Log prior probability of val
        """
        original_is_number = isinstance(val, Number)
        res = np.atleast_1d(self.slab.ln_prob(val) + np.log(self.slab_fraction))
        # Same array-based comparison as in `prob`, for scalar and list input.
        res[np.atleast_1d(val) == self.spike_location] = np.inf
        return self._match_input_type(res, original_is_number)

    def cdf(self, val):
        """ Return the CDF of the prior.
        This calls to the slab CDF and adds a discrete step
        at the spike location.

        The step is applied to values strictly above the spike location, so
        the value returned at the spike location itself does not include the
        spike's weight.

        Parameters
        ==========
        val: Union[float, int, array_like]

        Returns
        =======
        array_like: CDF value of val. The result is always an array with at
            least one dimension, also for scalar input.

        """
        res = np.atleast_1d(self.slab.cdf(val) * self.slab_fraction)
        # Array-based mask so that scalar input does not rely on the
        # deprecated single-argument `np.where` on a scalar condition.
        res[np.atleast_1d(val) > self.spike_location] += self.spike_height
        return res