Coverage for bilby/gw/waveform_generator.py: 97%
101 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import numpy as np
3from ..core import utils
4from ..core.series import CoupledTimeAndFrequencySeries
5from ..core.utils import PropertyAccessor
6from .conversion import convert_to_lal_binary_black_hole_parameters
7from .utils import lalsim_GetApproximantFromString
class WaveformGenerator(object):
    """
    The base waveform generator class.

    Waveform generators provide a unified method to call disparate source models.
    """

    duration = PropertyAccessor('_times_and_frequencies', 'duration')
    sampling_frequency = PropertyAccessor('_times_and_frequencies', 'sampling_frequency')
    start_time = PropertyAccessor('_times_and_frequencies', 'start_time')
    frequency_array = PropertyAccessor('_times_and_frequencies', 'frequency_array')
    time_array = PropertyAccessor('_times_and_frequencies', 'time_array')

    def __init__(self, duration=None, sampling_frequency=None, start_time=0, frequency_domain_source_model=None,
                 time_domain_source_model=None, parameters=None,
                 parameter_conversion=None,
                 waveform_arguments=None):
        """
        The base waveform generator class.

        Parameters
        ==========
        sampling_frequency: float, optional
            The sampling frequency
        duration: float, optional
            Time duration of data
        start_time: float, optional
            Starting time of the time array
        frequency_domain_source_model: func, optional
            A python function taking some arguments and returning the frequency
            domain strain. Note the first argument must be the frequencies at
            which to compute the strain
        time_domain_source_model: func, optional
            A python function taking some arguments and returning the time
            domain strain. Note the first argument must be the times at
            which to compute the strain
        parameters: dict, optional
            Initial values for the parameters. Ignored unless it is a dict.
        parameter_conversion: func, optional
            Function to convert from sampled parameters to parameters of the
            waveform generator. It is called as
            `converted, _ = parameter_conversion(parameters)`, so it must
            return a pair whose first element is the converted dictionary.
            Defaults to `convert_to_lal_binary_black_hole_parameters`.
        waveform_arguments: dict, optional
            A dictionary of fixed keyword arguments to pass to either
            `frequency_domain_source_model` or `time_domain_source_model`.

        Note: the named arguments of the source model (all except the first,
        which is the frequencies/times at which to compute the strain) are
        recorded in `self.source_parameter_keys`. The frequency-domain model
        is inspected if given, otherwise the time-domain model. Only those
        keys, plus `waveform_arguments`, are passed to the model.

        Raises
        ======
        AttributeError: If neither source model is given.
        """
        self._times_and_frequencies = CoupledTimeAndFrequencySeries(duration=duration,
                                                                    sampling_frequency=sampling_frequency,
                                                                    start_time=start_time)
        self.frequency_domain_source_model = frequency_domain_source_model
        self.time_domain_source_model = time_domain_source_model
        self.source_parameter_keys = self.__parameters_from_source_model()
        if parameter_conversion is None:
            self.parameter_conversion = convert_to_lal_binary_black_hole_parameters
        else:
            self.parameter_conversion = parameter_conversion
        if waveform_arguments is not None:
            self.waveform_arguments = waveform_arguments
        else:
            self.waveform_arguments = dict()
        if isinstance(parameters, dict):
            self.parameters = parameters
        # Memo of the last evaluated strain. Every key read by
        # _calculate_strain is initialised here, including 'transformed_model'.
        self._cache = dict(parameters=None, waveform=None, model=None, transformed_model=None)
        utils.logger.info(
            "Waveform generator initiated with\n"
            "  frequency_domain_source_model: {}\n"
            "  time_domain_source_model: {}\n"
            "  parameter_conversion: {}"
            .format(utils.get_function_path(self.frequency_domain_source_model),
                    utils.get_function_path(self.time_domain_source_model),
                    utils.get_function_path(self.parameter_conversion))
        )

    def __repr__(self):
        return self.__class__.__name__ + (
            '(duration={}, sampling_frequency={}, start_time={}, '
            'frequency_domain_source_model={}, time_domain_source_model={}, '
            'parameter_conversion={}, '
            'waveform_arguments={})'
        ).format(self.duration, self.sampling_frequency, self.start_time,
                 self._callable_name(self.frequency_domain_source_model),
                 self._callable_name(self.time_domain_source_model),
                 self._callable_name(self.parameter_conversion),
                 self.waveform_arguments)

    @staticmethod
    def _callable_name(function):
        """Return `function.__name__`, or None when `function` is None.

        Falls back to `repr(function)` for callables that have no `__name__`
        (e.g. `functools.partial` objects), instead of raising AttributeError.
        """
        if function is None:
            return None
        return getattr(function, '__name__', repr(function))

    def frequency_domain_strain(self, parameters=None):
        """ Wrapper to source_model.

        Converts self.parameters with self.parameter_conversion before handing it off to the source model.
        Automatically refers to the time_domain_source model via NFFT if no frequency_domain_source_model is given.

        Parameters
        ==========
        parameters: dict, optional
            Parameters to evaluate the waveform for, this overwrites
            `self.parameters`.
            If not provided will fall back to `self.parameters`.

        Returns
        =======
        array_like: The frequency domain strain for the given set of parameters

        Raises
        ======
        RuntimeError: If no source model is given

        """
        return self._calculate_strain(model=self.frequency_domain_source_model,
                                      model_data_points=self.frequency_array,
                                      parameters=parameters,
                                      transformation_function=utils.nfft,
                                      transformed_model=self.time_domain_source_model,
                                      transformed_model_data_points=self.time_array)

    def time_domain_strain(self, parameters=None):
        """ Wrapper to source_model.

        Converts self.parameters with self.parameter_conversion before handing it off to the source model.
        Automatically refers to the frequency_domain_source model via INFFT if no time_domain_source_model is
        given.

        Parameters
        ==========
        parameters: dict, optional
            Parameters to evaluate the waveform for, this overwrites
            `self.parameters`.
            If not provided will fall back to `self.parameters`.

        Returns
        =======
        array_like: The time domain strain for the given set of parameters

        Raises
        ======
        RuntimeError: If no source model is given

        """
        return self._calculate_strain(model=self.time_domain_source_model,
                                      model_data_points=self.time_array,
                                      parameters=parameters,
                                      transformation_function=utils.infft,
                                      transformed_model=self.frequency_domain_source_model,
                                      transformed_model_data_points=self.frequency_array)

    def _calculate_strain(self, model, model_data_points, transformation_function, transformed_model,
                          transformed_model_data_points, parameters):
        """
        Evaluate `model` directly or, if it is None, evaluate
        `transformed_model` and apply `transformation_function` to the result.

        The most recent result is memoised. If the current parameters and both
        model callables equal those of the previous call, the cached waveform
        object itself (not a copy) is returned.

        Raises
        ======
        RuntimeError: If both `model` and `transformed_model` are None.
        """
        if parameters is not None:
            self.parameters = parameters
        if (self.parameters == self._cache['parameters'] and self._cache['model'] == model
                and self._cache['transformed_model'] == transformed_model):
            return self._cache['waveform']
        if model is not None:
            model_strain = self._strain_from_model(model_data_points, model)
        elif transformed_model is not None:
            model_strain = self._strain_from_transformed_model(transformed_model_data_points, transformed_model,
                                                               transformation_function)
        else:
            raise RuntimeError("No source model given")
        self._cache['waveform'] = model_strain
        self._cache['parameters'] = self.parameters.copy()
        self._cache['model'] = model
        self._cache['transformed_model'] = transformed_model
        return model_strain

    def _strain_from_model(self, model_data_points, model):
        """Call `model` with the data points followed by `self.parameters` as keyword arguments."""
        return model(model_data_points, **self.parameters)

    def _strain_from_transformed_model(self, transformed_model_data_points, transformed_model, transformation_function):
        """
        Evaluate `transformed_model` and move its output into the other domain.

        A bare `np.ndarray` output is transformed directly. Any other output is
        treated as a mapping (e.g. of polarisations), and each entry is
        transformed separately into a new dict.
        """
        transformed_model_strain = self._strain_from_model(transformed_model_data_points, transformed_model)

        if isinstance(transformed_model_strain, np.ndarray):
            return self._transform_strain(transformed_model_strain, transformation_function)

        return {key: self._transform_strain(transformed_model_strain[key], transformation_function)
                for key in transformed_model_strain}

    def _transform_strain(self, strain, transformation_function):
        """Apply `transformation_function` to one strain series and return only the strain.

        `utils.nfft` returns a pair whose first element is the strain. It is
        unpacked here for both the array and the mapping code paths, so that
        `frequency_domain_strain` never returns the raw pair.
        """
        transformed = transformation_function(strain, self.sampling_frequency)
        if transformation_function == utils.nfft:
            transformed, _ = transformed
        return transformed

    @property
    def parameters(self):
        """ The dictionary of parameters for source model.

        Returns
        =======
        dict: The dictionary of parameter key-value pairs

        """
        return self.__parameters

    @parameters.setter
    def parameters(self, parameters):
        """
        Set parameters, this applies the conversion function and then removes
        any parameters which aren't required by the source function.

        Parameters
        ==========
        parameters: dict
            Input parameter dictionary, this is copied, passed to the conversion
            function and has self.waveform_arguments added to it.

        Raises
        ======
        TypeError: If `parameters` is not a dictionary.
        KeyError: If a named argument of the source model is supplied neither
            by the converted parameters nor by `self.waveform_arguments`.
        """
        if not isinstance(parameters, dict):
            raise TypeError('"parameters" must be a dictionary.')
        converted_parameters, _ = self.parameter_conversion(parameters.copy())
        # Keys fixed through waveform_arguments count as supplied: they are
        # merged in below and passed to the source model with everything else.
        missing = self.source_parameter_keys.difference(converted_parameters, self.waveform_arguments)
        if missing:
            raise KeyError('Missing source model parameters: {}'.format(sorted(missing)))
        # Keep only the arguments the source model names, preserving key order.
        new_parameters = {key: value for key, value in converted_parameters.items()
                          if key in self.source_parameter_keys}
        new_parameters.update(self.waveform_arguments)
        self.__parameters = new_parameters

    def __parameters_from_source_model(self):
        """
        Infer the named arguments of the source model.

        Returns
        =======
        set: The names of the arguments of the source model.

        Raises
        ======
        AttributeError: If neither source model is set.
        """
        if self.frequency_domain_source_model is not None:
            model = self.frequency_domain_source_model
        elif self.time_domain_source_model is not None:
            model = self.time_domain_source_model
        else:
            raise AttributeError('Either time or frequency domain source '
                                 'model must be provided.')
        return set(utils.infer_parameters_from_function(model))
class LALCBCWaveformGenerator(WaveformGenerator):
    """ A waveform generator with specific checks for LAL CBC waveforms

    After the base-class initialisation, the configured waveform approximant
    is checked for compatibility with the `reference_frequency` and
    `minimum_frequency` entries of `waveform_arguments`.
    """

    # Compared against the return value of lalsimulation's
    # SimInspiralGetSpinFreqFromApproximant. The name mirrors LAL's
    # LAL_SIM_INSPIRAL_SPINS_FLOW (spins referenced at f_low) -- confirm
    # against LALSimInspiral.h if the numeric value ever changes.
    LAL_SIM_INSPIRAL_SPINS_FLOW = 1

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.validate_reference_frequency()

    def validate_reference_frequency(self):
        """
        Check the reference frequency against the approximant's spin convention.

        Reads the `waveform_approximant` entry of `self.waveform_arguments`.
        If lalsimulation reports the `LAL_SIM_INSPIRAL_SPINS_FLOW` convention
        for it, the `reference_frequency` and `minimum_frequency` entries must
        be equal.

        Raises
        ======
        ValueError: If the approximant uses the f_low spin convention and
            the two frequencies differ.
        KeyError: If any of the waveform arguments read here is absent.
        """
        from lalsimulation import SimInspiralGetSpinFreqFromApproximant

        arguments = self.waveform_arguments
        approximant_name = arguments["waveform_approximant"]
        approximant_id = lalsim_GetApproximantFromString(approximant_name)
        spin_convention = SimInspiralGetSpinFreqFromApproximant(approximant_id)
        if spin_convention != self.LAL_SIM_INSPIRAL_SPINS_FLOW:
            return
        if arguments["reference_frequency"] == arguments["minimum_frequency"]:
            return
        raise ValueError(f"For {approximant_name}, reference_frequency must equal minimum_frequency")