Coverage for bilby/gw/waveform_generator.py: 97%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import numpy as np 

2 

3from ..core import utils 

4from ..core.series import CoupledTimeAndFrequencySeries 

5from ..core.utils import PropertyAccessor 

6from .conversion import convert_to_lal_binary_black_hole_parameters 

7from .utils import lalsim_GetApproximantFromString 

8 

9 

10class WaveformGenerator(object): 

11 """ 

12 The base waveform generator class. 

13 

14 Waveform generators provide a unified method to call disparate source models. 

15 """ 

16 

17 duration = PropertyAccessor('_times_and_frequencies', 'duration') 

18 sampling_frequency = PropertyAccessor('_times_and_frequencies', 'sampling_frequency') 

19 start_time = PropertyAccessor('_times_and_frequencies', 'start_time') 

20 frequency_array = PropertyAccessor('_times_and_frequencies', 'frequency_array') 

21 time_array = PropertyAccessor('_times_and_frequencies', 'time_array') 

22 

23 def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None, 

24 time_domain_source_model=None, parameters=None, 

25 parameter_conversion=None, 

26 waveform_arguments=None): 

27 """ 

28 The base waveform generator class. 

29 

30 Parameters 

31 ========== 

32 sampling_frequency: float, optional 

33 The sampling frequency 

34 duration: float, optional 

35 Time duration of data 

36 start_time: float, optional 

37 Starting time of the time array 

38 frequency_domain_source_model: func, optional 

39 A python function taking some arguments and returning the frequency 

40 domain strain. Note the first argument must be the frequencies at 

41 which to compute the strain 

42 time_domain_source_model: func, optional 

43 A python function taking some arguments and returning the time 

44 domain strain. Note the first argument must be the times at 

45 which to compute the strain 

46 parameters: dict, optional 

47 Initial values for the parameters 

48 parameter_conversion: func, optional 

49 Function to convert from sampled parameters to parameters of the 

50 waveform generator. Default value is the identity, i.e. it leaves 

51 the parameters unaffected. 

52 waveform_arguments: dict, optional 

53 A dictionary of fixed keyword arguments to pass to either 

54 `frequency_domain_source_model` or `time_domain_source_model`. 

55 

56 Note: the arguments of frequency_domain_source_model (except the first, 

57 which is the frequencies at which to compute the strain) will be added to 

58 the WaveformGenerator object and initialised to `None`. 

59 

60 """ 

61 self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration, 

62 sampling_frequency=sampling_frequency, 

63 start_time=start_time) 

64 self.frequency_domain_source_model = frequency_domain_source_model 

65 self.time_domain_source_model = time_domain_source_model 

66 self.source_parameter_keys = self.__parameters_from_source_model() 

67 if parameter_conversion is None: 

68 self.parameter_conversion = convert_to_lal_binary_black_hole_parameters 

69 else: 

70 self.parameter_conversion = parameter_conversion 

71 if waveform_arguments is not None: 

72 self.waveform_arguments = waveform_arguments 

73 else: 

74 self.waveform_arguments = dict() 

75 if isinstance(parameters, dict): 

76 self.parameters = parameters 

77 self._cache = dict(parameters=None, waveform=None, model=None) 

78 utils.logger.info( 

79 "Waveform generator initiated with\n" 

80 " frequency_domain_source_model: {}\n" 

81 " time_domain_source_model: {}\n" 

82 " parameter_conversion: {}" 

83 .format(utils.get_function_path(self.frequency_domain_source_model), 

84 utils.get_function_path(self.time_domain_source_model), 

85 utils.get_function_path(self.parameter_conversion)) 

86 ) 

87 

88 def __repr__(self): 

89 if self.frequency_domain_source_model is not None: 

90 fdsm_name = self.frequency_domain_source_model.__name__ 

91 else: 

92 fdsm_name = None 

93 if self.time_domain_source_model is not None: 

94 tdsm_name = self.time_domain_source_model.__name__ 

95 else: 

96 tdsm_name = None 

97 if self.parameter_conversion is None: 

98 param_conv_name = None 

99 else: 

100 param_conv_name = self.parameter_conversion.__name__ 

101 

102 return self.__class__.__name__ + '(duration={}, sampling_frequency={}, start_time={}, ' \ 

103 'frequency_domain_source_model={}, time_domain_source_model={}, ' \ 

104 'parameter_conversion={}, ' \ 

105 'waveform_arguments={})'\ 

106 .format(self.duration, self.sampling_frequency, self.start_time, fdsm_name, tdsm_name, 

107 param_conv_name, self.waveform_arguments) 

108 

109 def frequency_domain_strain(self, parameters=None): 

110 """ Wrapper to source_model. 

111 

112 Converts self.parameters with self.parameter_conversion before handing it off to the source model. 

113 Automatically refers to the time_domain_source model via NFFT if no frequency_domain_source_model is given. 

114 

115 Parameters 

116 ========== 

117 parameters: dict, optional 

118 Parameters to evaluate the waveform for, this overwrites 

119 `self.parameters`. 

120 If not provided will fall back to `self.parameters`. 

121 

122 Returns 

123 ======= 

124 array_like: The frequency domain strain for the given set of parameters 

125 

126 Raises 

127 ====== 

128 RuntimeError: If no source model is given 

129 

130 """ 

131 return self._calculate_strain(model=self.frequency_domain_source_model, 

132 model_data_points=self.frequency_array, 

133 parameters=parameters, 

134 transformation_function=utils.nfft, 

135 transformed_model=self.time_domain_source_model, 

136 transformed_model_data_points=self.time_array) 

137 

138 def time_domain_strain(self, parameters=None): 

139 """ Wrapper to source_model. 

140 

141 Converts self.parameters with self.parameter_conversion before handing it off to the source model. 

142 Automatically refers to the frequency_domain_source model via INFFT if no frequency_domain_source_model is 

143 given. 

144 

145 Parameters 

146 ========== 

147 parameters: dict, optional 

148 Parameters to evaluate the waveform for, this overwrites 

149 `self.parameters`. 

150 If not provided will fall back to `self.parameters`. 

151 

152 Returns 

153 ======= 

154 array_like: The time domain strain for the given set of parameters 

155 

156 Raises 

157 ====== 

158 RuntimeError: If no source model is given 

159 

160 """ 

161 return self._calculate_strain(model=self.time_domain_source_model, 

162 model_data_points=self.time_array, 

163 parameters=parameters, 

164 transformation_function=utils.infft, 

165 transformed_model=self.frequency_domain_source_model, 

166 transformed_model_data_points=self.frequency_array) 

167 

168 def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model, 

169 transformed_model_data_points, parameters): 

170 if parameters is not None: 

171 self.parameters = parameters 

172 if self.parameters == self._cache['parameters'] and self._cache['model'] == model and \ 

173 self._cache['transformed_model'] == transformed_model: 

174 return self._cache['waveform'] 

175 if model is not None: 

176 model_strain = self._strain_from_model(model_data_points, model) 

177 elif transformed_model is not None: 

178 model_strain = self._strain_from_transformed_model(transformed_model_data_points, transformed_model, 

179 transformation_function) 

180 else: 

181 raise RuntimeError("No source model given") 

182 self._cache['waveform'] = model_strain 

183 self._cache['parameters'] = self.parameters.copy() 

184 self._cache['model'] = model 

185 self._cache['transformed_model'] = transformed_model 

186 return model_strain 

187 

188 def _strain_from_model(self, model_data_points, model): 

189 return model(model_data_points, **self.parameters) 

190 

191 def _strain_from_transformed_model(self, transformed_model_data_points, transformed_model, transformation_function): 

192 transformed_model_strain = self._strain_from_model(transformed_model_data_points, transformed_model) 

193 

194 if isinstance(transformed_model_strain, np.ndarray): 

195 return transformation_function(transformed_model_strain, self.sampling_frequency) 

196 

197 model_strain = dict() 

198 for key in transformed_model_strain: 

199 if transformation_function == utils.nfft: 

200 model_strain[key], _ = \ 

201 transformation_function(transformed_model_strain[key], self.sampling_frequency) 

202 else: 

203 model_strain[key] = transformation_function(transformed_model_strain[key], self.sampling_frequency) 

204 return model_strain 

205 

206 @property 

207 def parameters(self): 

208 """ The dictionary of parameters for source model. 

209 

210 Returns 

211 ======= 

212 dict: The dictionary of parameter key-value pairs 

213 

214 """ 

215 return self.__parameters 

216 

217 @parameters.setter 

218 def parameters(self, parameters): 

219 """ 

220 Set parameters, this applies the conversion function and then removes 

221 any parameters which aren't required by the source function. 

222 

223 (set.symmetric_difference is the opposite of set.intersection) 

224 

225 Parameters 

226 ========== 

227 parameters: dict 

228 Input parameter dictionary, this is copied, passed to the conversion 

229 function and has self.waveform_arguments added to it. 

230 """ 

231 if not isinstance(parameters, dict): 

232 raise TypeError('"parameters" must be a dictionary.') 

233 new_parameters = parameters.copy() 

234 new_parameters, _ = self.parameter_conversion(new_parameters) 

235 for key in self.source_parameter_keys.symmetric_difference( 

236 new_parameters): 

237 new_parameters.pop(key) 

238 self.__parameters = new_parameters 

239 self.__parameters.update(self.waveform_arguments) 

240 

241 def __parameters_from_source_model(self): 

242 """ 

243 Infer the named arguments of the source model. 

244 

245 Returns 

246 ======= 

247 set: The names of the arguments of the source model. 

248 """ 

249 if self.frequency_domain_source_model is not None: 

250 model = self.frequency_domain_source_model 

251 elif self.time_domain_source_model is not None: 

252 model = self.time_domain_source_model 

253 else: 

254 raise AttributeError('Either time or frequency domain source ' 

255 'model must be provided.') 

256 return set(utils.infer_parameters_from_function(model)) 

257 

258 

259class LALCBCWaveformGenerator(WaveformGenerator): 

260 """ A waveform generator with specific checks for LAL CBC waveforms """ 

261 LAL_SIM_INSPIRAL_SPINS_FLOW = 1 

262 

263 def __init__(self, **kwargs): 

264 super().__init__(**kwargs) 

265 self.validate_reference_frequency() 

266 

267 def validate_reference_frequency(self): 

268 from lalsimulation import SimInspiralGetSpinFreqFromApproximant 

269 waveform_approximant = self.waveform_arguments["waveform_approximant"] 

270 waveform_approximant_number = lalsim_GetApproximantFromString(waveform_approximant) 

271 if SimInspiralGetSpinFreqFromApproximant(waveform_approximant_number) == self.LAL_SIM_INSPIRAL_SPINS_FLOW: 

272 if self.waveform_arguments["reference_frequency"] != self.waveform_arguments["minimum_frequency"]: 

273 raise ValueError(f"For {waveform_approximant}, reference_frequency must equal minimum_frequency")