Coverage for bilby/gw/detector/strain_data.py: 63%
299 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import numpy as np
3from ...core import utils
4from ...core.series import CoupledTimeAndFrequencySeries
5from ...core.utils import logger, PropertyAccessor
6from .. import utils as gwutils
class InterferometerStrainData(object):
    """ Strain data for an interferometer

    The strain can be supplied in either the time or the frequency domain;
    the other representation is derived on first access (an inverse FFT for
    the time domain, a Tukey-windowed FFT for the frequency domain). The
    frequency-domain strain returned by ``frequency_domain_strain`` is always
    multiplied by ``frequency_mask``, which is built from
    ``minimum_frequency``, ``maximum_frequency`` and ``notch_list``.
    """

    # Grid attributes forwarded to the CoupledTimeAndFrequencySeries stored in
    # ``_times_and_frequencies`` (presumably plain attribute get/set; see
    # bilby.core.utils.PropertyAccessor).
    duration = PropertyAccessor('_times_and_frequencies', 'duration')
    sampling_frequency = PropertyAccessor('_times_and_frequencies', 'sampling_frequency')
    start_time = PropertyAccessor('_times_and_frequencies', 'start_time')
    frequency_array = PropertyAccessor('_times_and_frequencies', 'frequency_array')
    time_array = PropertyAccessor('_times_and_frequencies', 'time_array')

    def __init__(self, minimum_frequency=0, maximum_frequency=np.inf,
                 roll_off=0.2, notch_list=None):
        """ Initiate an InterferometerStrainData object

        The initialised object contains no data, this should be added using one
        of the `set_from..` methods.

        Parameters
        ==========
        minimum_frequency: float
            Minimum frequency to analyse for detector. Default is 0.
        maximum_frequency: float
            Maximum frequency to analyse for detector. Default is infinity.
        roll_off: float
            The roll-off (in seconds) used in the Tukey window, default=0.2s.
            This corresponds to alpha * duration / 2 for scipy tukey window.
        notch_list: None, list, bilby.gw.detector.strain_data.NotchList
            The notches to exclude from the analysis (see ``notch_list``).

        """

        self.minimum_frequency = minimum_frequency
        self.maximum_frequency = maximum_frequency
        self.notch_list = notch_list
        self.roll_off = roll_off
        # Power loss due to time-domain windowing; 1 means no window applied.
        self.window_factor = 1

        self._times_and_frequencies = CoupledTimeAndFrequencySeries()

        self._frequency_mask_updated = False
        self._frequency_mask = None
        self._frequency_domain_strain = None
        self._time_domain_strain = None
        self._channel = None

    def __eq__(self, other):
        """ Compare analysis settings, grid and strain with another object.

        Returns ``NotImplemented`` for objects that are not
        InterferometerStrainData, so Python falls back to its default
        comparison instead of raising AttributeError.
        """
        if not isinstance(other, InterferometerStrainData):
            return NotImplemented
        return bool(
            self.minimum_frequency == other.minimum_frequency
            and self.maximum_frequency == other.maximum_frequency
            and self.roll_off == other.roll_off
            and self.window_factor == other.window_factor
            and self.sampling_frequency == other.sampling_frequency
            and self.duration == other.duration
            and self.start_time == other.start_time
            and np.array_equal(self.time_array, other.time_array)
            and np.array_equal(self.frequency_array, other.frequency_array)
            and np.array_equal(self.frequency_domain_strain, other.frequency_domain_strain)
            and np.array_equal(self.time_domain_strain, other.time_domain_strain)
        )

    def time_within_data(self, time):
        """ Check if time is within the data span

        Parameters
        ==========
        time: float
            The time to check

        Returns
        =======
        bool:
            A boolean stating whether the time is inside or outside the span
            (both end points count as inside).

        """
        if time < self.start_time:
            logger.debug("Time is before the start_time")
            return False
        elif time > self.start_time + self.duration:
            logger.debug("Time is after the start_time + duration")
            return False
        else:
            return True

    @property
    def minimum_frequency(self):
        """ Lower edge of the band kept by ``frequency_mask``. """
        return self._minimum_frequency

    @minimum_frequency.setter
    def minimum_frequency(self, minimum_frequency):
        self._minimum_frequency = minimum_frequency
        # Force the mask to be rebuilt on next access.
        self._frequency_mask_updated = False

    @property
    def maximum_frequency(self):
        """ Force the maximum frequency be no greater than the Nyquist frequency

        Once a sampling frequency is known, a stored value above Nyquist is
        replaced (persistently) by ``sampling_frequency / 2``.
        """
        if self.sampling_frequency is not None:
            if 2 * self._maximum_frequency > self.sampling_frequency:
                self._maximum_frequency = self.sampling_frequency / 2.
        return self._maximum_frequency

    @maximum_frequency.setter
    def maximum_frequency(self, maximum_frequency):
        self._maximum_frequency = maximum_frequency
        # Force the mask to be rebuilt on next access.
        self._frequency_mask_updated = False

    @property
    def notch_list(self):
        """ The NotchList of frequency bands removed by ``frequency_mask``. """
        return self._notch_list

    @notch_list.setter
    def notch_list(self, notch_list):
        """ Set the notch_list

        Parameters
        ==========
        notch_list: None, list, bilby.gw.detector.strain_data.NotchList
            A list of length-2 tuples of the (min, max) frequency for the
            notches, a pre-made bilby NotchList, or None for no notches.

        Raises
        ======
        ValueError
            If notch_list is not one of the accepted types.
        """
        if notch_list is None:
            self._notch_list = NotchList(None)
        elif isinstance(notch_list, NotchList):
            # Must be tested before ``list``: NotchList subclasses list, and
            # re-wrapping it would reject its Notch entries as malformed.
            self._notch_list = notch_list
        elif isinstance(notch_list, list):
            self._notch_list = NotchList(notch_list)
        else:
            raise ValueError("notch_list {} not understood".format(notch_list))
        self._frequency_mask_updated = False

    @property
    def frequency_mask(self):
        """ Masking array for limiting the frequency band.

        The mask is True for frequencies in
        [minimum_frequency, maximum_frequency] that are not inside any notch.
        It is cached until the band, the notches or the grid change.

        Returns
        =======
        mask: np.ndarray
            An array of boolean values
        """
        if not self._frequency_mask_updated:
            frequency_array = self._times_and_frequencies.frequency_array
            mask = ((frequency_array >= self.minimum_frequency) &
                    (frequency_array <= self.maximum_frequency))
            for notch in self.notch_list:
                mask[notch.get_idxs(frequency_array)] = False
            self._frequency_mask = mask
            self._frequency_mask_updated = True
        return self._frequency_mask

    @frequency_mask.setter
    def frequency_mask(self, mask):
        # A user-supplied mask is kept until the band, notches or grid change.
        self._frequency_mask = mask
        self._frequency_mask_updated = True

    @property
    def alpha(self):
        """ Tukey window shape parameter, ``2 * roll_off / duration``. """
        return 2 * self.roll_off / self.duration

    def time_domain_window(self, roll_off=None, alpha=None):
        """
        Window function to apply to time domain data before FFTing.

        This defines self.window_factor as the power loss due to the windowing.
        See https://dcc.ligo.org/DocDB/0027/T040089/000/T040089-00.pdf

        Parameters
        ==========
        roll_off: float
            Rise time of window in seconds; if given it updates self.roll_off.
        alpha: float
            Parameter to pass to tukey window, how much of segment falls
            into windowed part; only used when roll_off is not given.

        Returns
        =======
        window: array
            Window function over time array
        """
        from scipy.signal.windows import tukey
        if roll_off is not None:
            self.roll_off = roll_off
        elif alpha is not None:
            self.roll_off = alpha * self.duration / 2
        window = tukey(len(self._time_domain_strain), alpha=self.alpha)
        self.window_factor = np.mean(window ** 2)
        return window

    @property
    def time_domain_strain(self):
        """ The time domain strain, in units of strain

        If only frequency-domain strain was set, it is inverse-FFT'd once and
        cached.

        Raises
        ======
        ValueError
            If no strain has been set.
        """
        if self._time_domain_strain is not None:
            return self._time_domain_strain
        elif self._frequency_domain_strain is not None:
            self._time_domain_strain = utils.infft(
                self.frequency_domain_strain, self.sampling_frequency)
            return self._time_domain_strain

        else:
            raise ValueError("time domain strain data not yet set")

    @property
    def frequency_domain_strain(self):
        """ Returns the frequency domain strain

        This is the frequency domain strain normalised to units of
        strain / Hz, obtained by a one-sided Fourier transform of the
        time domain data, divided by the sampling frequency. If only
        time-domain strain was set, it is Tukey-windowed (see
        ``time_domain_window``) and FFT'd once, and the result cached. The
        returned array is multiplied by ``frequency_mask``.

        Raises
        ======
        ValueError
            If no strain has been set.
        """
        if self._frequency_domain_strain is not None:
            return self._frequency_domain_strain * self.frequency_mask
        elif self._time_domain_strain is not None:
            logger.debug("Generating frequency domain strain from given time "
                         "domain strain.")
            logger.debug("Applying a tukey window with alpha={}, roll off={}".format(
                self.alpha, self.roll_off))
            window = self.time_domain_window()
            self._frequency_domain_strain, self.frequency_array = utils.nfft(
                self._time_domain_strain * window, self.sampling_frequency)
            return self._frequency_domain_strain * self.frequency_mask
        else:
            raise ValueError("frequency domain strain data not yet set")

    @frequency_domain_strain.setter
    def frequency_domain_strain(self, frequency_domain_strain):
        if not len(self.frequency_array) == len(frequency_domain_strain):
            raise ValueError("The frequency_array and the set strain have different lengths")
        self._frequency_domain_strain = frequency_domain_strain
        # Invalidate the cached time-domain representation.
        self._time_domain_strain = None

    def to_gwpy_timeseries(self):
        """
        Output the time series strain data as a :class:`gwpy.timeseries.TimeSeries`.
        """
        try:
            from gwpy.timeseries import TimeSeries
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError("Cannot output strain data as gwpy TimeSeries") from err

        return TimeSeries(
            self.time_domain_strain, sample_rate=self.sampling_frequency,
            t0=self.start_time, channel=self.channel
        )

    def to_pycbc_timeseries(self):
        """
        Output the time series strain data as a :class:`pycbc.types.timeseries.TimeSeries`.
        """

        try:
            from pycbc.types.timeseries import TimeSeries
            from lal import LIGOTimeGPS
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError("Cannot output strain data as PyCBC TimeSeries") from err

        return TimeSeries(
            self.time_domain_strain, delta_t=(1. / self.sampling_frequency),
            epoch=LIGOTimeGPS(self.start_time)
        )

    def to_lal_timeseries(self):
        """
        Output the time series strain data as a LAL TimeSeries object.
        """
        try:
            from lal import CreateREAL8TimeSeries, LIGOTimeGPS, SecondUnit
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError("Cannot output strain data as LAL TimeSeries") from err

        lal_data = CreateREAL8TimeSeries(
            "", LIGOTimeGPS(self.start_time), 0, 1 / self.sampling_frequency,
            SecondUnit, len(self.time_domain_strain)
        )
        lal_data.data.data[:] = self.time_domain_strain

        return lal_data

    def to_gwpy_frequencyseries(self):
        """
        Output the frequency series strain data as a :class:`gwpy.frequencyseries.FrequencySeries`.
        """
        try:
            from gwpy.frequencyseries import FrequencySeries
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError("Cannot output strain data as gwpy FrequencySeries") from err

        return FrequencySeries(
            self.frequency_domain_strain,
            frequencies=self.frequency_array,
            epoch=self.start_time,
            channel=self.channel
        )

    def to_pycbc_frequencyseries(self):
        """
        Output the frequency series strain data as a :class:`pycbc.types.frequencyseries.FrequencySeries`.
        """

        try:
            from pycbc.types.frequencyseries import FrequencySeries
            from lal import LIGOTimeGPS
        except ImportError as err:
            raise ImportError("Cannot output strain data as PyCBC FrequencySeries") from err

        return FrequencySeries(
            self.frequency_domain_strain,
            delta_f=1 / self.duration,
            epoch=LIGOTimeGPS(self.start_time)
        )

    def to_lal_frequencyseries(self):
        """
        Output the frequency series strain data as a LAL FrequencySeries object.
        """
        try:
            from lal import CreateCOMPLEX16FrequencySeries, LIGOTimeGPS, SecondUnit
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError("Cannot output strain data as LAL FrequencySeries") from err

        lal_data = CreateCOMPLEX16FrequencySeries(
            "",
            LIGOTimeGPS(self.start_time),
            self.frequency_array[0],
            1 / self.duration,
            SecondUnit,
            len(self.frequency_domain_strain)
        )
        lal_data.data.data[:] = self.frequency_domain_strain

        return lal_data

    def low_pass_filter(self, filter_freq=None):
        """ Low pass filter the time-domain data in place

        Parameters
        ==========
        filter_freq: float, optional
            Filter frequency in Hz; defaults to ``maximum_frequency``. No
            filter is applied if it is at or above the Nyquist frequency.
        """
        from gwpy.signal.filter_design import lowpass
        from gwpy.timeseries import TimeSeries

        if filter_freq is None:
            logger.debug(
                "Setting low pass filter_freq using given maximum frequency")
            filter_freq = self.maximum_frequency

        if 2 * filter_freq >= self.sampling_frequency:
            logger.info(
                "Low pass filter frequency of {}Hz requested, this is equal"
                " or greater than the Nyquist frequency so no filter applied"
                .format(filter_freq))
            return

        logger.debug("Applying low pass filter with filter frequency {}".format(filter_freq))
        bp = lowpass(filter_freq, self.sampling_frequency)
        strain = TimeSeries(self.time_domain_strain, sample_rate=self.sampling_frequency)
        strain = strain.filter(bp, filtfilt=True)
        self._time_domain_strain = strain.value

    def create_power_spectral_density(
            self, fft_length, overlap=0, name='unknown', outdir=None,
            analysis_segment_start_time=None):
        """ Use the time domain strain to generate a power spectral density

        This creates a Tukey-windowed power spectral density and, if outdir
        is given, writes it to a PSD file.

        Parameters
        ==========
        fft_length: float
            Duration (in s) of each FFT segment; also used as the length of
            the analysis segment removed when analysis_segment_start_time is
            given.
        overlap: float
            Number of seconds of overlap between FFTs.
        name: str
            The name of the detector, used in storing the PSD. Defaults to
            "unknown".
        outdir: str
            The output directory to write the PSD file to. If not given,
            the PSD will not be written to file.
        analysis_segment_start_time: float
            The start time of the analysis segment, if given and the segment
            lies strictly inside the data, this data will be removed before
            creating the PSD.

        Returns
        =======
        frequency_array, psd : array_like
            The frequencies and power spectral density array

        """
        from gwpy.timeseries import TimeSeries

        data = self.time_domain_strain

        if analysis_segment_start_time is not None:
            analysis_segment_end_time = analysis_segment_start_time + fft_length
            # Both edges must be checked separately; the previous
            # ``a > t0 + b < t1`` form was a chained comparison that was
            # effectively never True.
            inside = (analysis_segment_start_time > self.time_array[0] and
                      analysis_segment_end_time < self.time_array[-1])
            if inside:
                logger.info("Removing analysis segment data from the PSD data")
                idxs = (
                    (self.time_array < analysis_segment_start_time) |
                    (self.time_array > analysis_segment_end_time))
                data = data[idxs]

        # WARNING this line can cause issues if the data is non-contiguous
        strain = TimeSeries(data=data, sample_rate=self.sampling_frequency)
        psd_alpha = 2 * self.roll_off / fft_length
        logger.info(
            "Tukey window PSD data with alpha={}, roll off={}".format(
                psd_alpha, self.roll_off))
        psd = strain.psd(
            fftlength=fft_length, overlap=overlap, window=('tukey', psd_alpha))

        if outdir:
            psd_file = '{}/{}_PSD_{}_{}.txt'.format(outdir, name, self.start_time, self.duration)
            with open(psd_file, 'w+') as opened_file:
                for f, p in zip(psd.frequencies.value, psd.value):
                    opened_file.write('{} {}\n'.format(f, p))

        return psd.frequencies.value, psd.value

    def _set_time_and_frequency_grid(self, duration, sampling_frequency, start_time):
        """ Replace the time/frequency grid and invalidate the cached mask. """
        self._times_and_frequencies = CoupledTimeAndFrequencySeries(
            duration=duration, sampling_frequency=sampling_frequency,
            start_time=start_time)
        # A mask built for the previous grid no longer matches this one.
        self._frequency_mask_updated = False

    def _infer_time_domain_dependence(
            self, start_time, sampling_frequency, duration, time_array):
        """ Helper function to figure out if the time_array, or
        sampling_frequency and duration were given
        """
        self._infer_dependence(domain='time', array=time_array, duration=duration,
                               sampling_frequency=sampling_frequency, start_time=start_time)

    def _infer_frequency_domain_dependence(
            self, start_time, sampling_frequency, duration, frequency_array):
        """ Helper function to figure out if the frequency_array, or
        sampling_frequency and duration were given
        """

        self._infer_dependence(domain='frequency', array=frequency_array,
                               duration=duration, sampling_frequency=sampling_frequency,
                               start_time=start_time)

    def _infer_dependence(self, domain, array, duration, sampling_frequency, start_time):
        """ Set the grid from either (sampling_frequency, duration) or an array.

        Raises
        ======
        ValueError
            If both the pair and an array are given, or if neither an array
            nor both of sampling_frequency and duration are given.
        """
        if sampling_frequency is not None and duration is not None:
            if array is not None:
                raise ValueError(
                    "You have given the sampling_frequency, duration, and "
                    "an array")
            self._set_time_and_frequency_grid(
                duration=duration, sampling_frequency=sampling_frequency,
                start_time=start_time)
            return
        if array is not None:
            if domain == 'time':
                self.time_array = array
            elif domain == 'frequency':
                self.frequency_array = array
            self.start_time = start_time
            self._frequency_mask_updated = False
            return
        raise ValueError(
            "You must provide both sampling_frequency and duration")

    def set_from_time_domain_strain(
            self, time_domain_strain, sampling_frequency=None, duration=None,
            start_time=0, time_array=None):
        """ Set the strain data from a time domain strain array

        This sets the time_domain_strain attribute, the frequency_domain_strain
        is automatically calculated on access after a Tukey window is applied.

        Parameters
        ==========
        time_domain_strain: array_like
            An array of the time domain strain.
        sampling_frequency: float
            The sampling frequency (in Hz).
        duration: float
            The data duration (in s).
        start_time: float
            The GPS start-time of the data.
        time_array: array_like
            The array of times, if sampling_frequency and duration not
            given.

        Raises
        ======
        ValueError
            If the strain shape does not match the time array.
        """
        self._infer_time_domain_dependence(start_time=start_time,
                                           sampling_frequency=sampling_frequency,
                                           duration=duration,
                                           time_array=time_array)

        logger.debug('Setting data using provided time_domain_strain')
        if np.shape(time_domain_strain) == np.shape(self.time_array):
            self._time_domain_strain = time_domain_strain
            self._frequency_domain_strain = None
        else:
            raise ValueError("Data times do not match time array")

    def set_from_gwpy_timeseries(self, time_series):
        """ Set the strain data from a gwpy TimeSeries

        This sets the time_domain_strain attribute, the frequency_domain_strain
        is automatically calculated on access after a Tukey window is applied.

        Parameters
        ==========
        time_series: gwpy.timeseries.timeseries.TimeSeries
            The data to use

        Raises
        ======
        ValueError
            If time_series is not a gwpy TimeSeries.
        """
        from gwpy.timeseries import TimeSeries
        logger.debug('Setting data using provided gwpy TimeSeries object')
        if not isinstance(time_series, TimeSeries):
            raise ValueError("Input time_series is not a gwpy TimeSeries")
        self._set_time_and_frequency_grid(
            duration=time_series.duration.value,
            sampling_frequency=time_series.sample_rate.value,
            start_time=time_series.epoch.value)
        self._time_domain_strain = time_series.value
        self._frequency_domain_strain = None
        self._channel = time_series.channel

    @property
    def channel(self):
        """ Channel of the last gwpy TimeSeries used to set the data, if any. """
        return self._channel

    def set_from_open_data(
            self, name, start_time, duration=4, outdir='outdir', cache=True,
            **kwargs):
        """ Set the strain data from open LOSC data

        This sets the time_domain_strain attribute, the frequency_domain_strain
        is automatically calculated on access after a Tukey window is applied.

        Parameters
        ==========
        name: str
            Detector name, e.g., 'H1'.
        start_time: float
            Start GPS time of segment.
        duration: float, optional
            The total time (in seconds) to analyse. Defaults to 4s.
        outdir: str
            Directory where the psd files are saved
        cache: bool, optional
            Whether or not to store/use the acquired data.
        **kwargs:
            All keyword arguments are passed to
            `gwpy.timeseries.TimeSeries.fetch_open_data()`.

        """

        timeseries = gwutils.get_open_strain_data(
            name, start_time, start_time + duration, outdir=outdir, cache=cache,
            **kwargs)

        self.set_from_gwpy_timeseries(timeseries)

    def set_from_csv(self, filename):
        """ Set the strain data from a csv file

        Parameters
        ==========
        filename: str
            The path to the file to read in

        """
        from gwpy.timeseries import TimeSeries
        timeseries = TimeSeries.read(filename, format='csv')
        self.set_from_gwpy_timeseries(timeseries)

    def set_from_frequency_domain_strain(
            self, frequency_domain_strain, sampling_frequency=None,
            duration=None, start_time=0, frequency_array=None):
        """ Set the `frequency_domain_strain` from a numpy array

        Parameters
        ==========
        frequency_domain_strain: array_like
            The data to set.
        sampling_frequency: float
            The sampling frequency (in Hz).
        duration: float
            The data duration (in s).
        start_time: float
            The GPS start-time of the data.
        frequency_array: array_like
            The array of frequencies, if sampling_frequency and duration not
            given.

        Raises
        ======
        ValueError
            If the strain shape does not match the frequency array.
        """

        self._infer_frequency_domain_dependence(start_time=start_time,
                                                sampling_frequency=sampling_frequency,
                                                duration=duration,
                                                frequency_array=frequency_array)

        logger.debug('Setting data using provided frequency_domain_strain')
        if np.shape(frequency_domain_strain) == np.shape(self.frequency_array):
            self._frequency_domain_strain = frequency_domain_strain
            # Drop any time-domain strain from a previous set_from_* call so
            # it is regenerated from the new frequency-domain data.
            self._time_domain_strain = None
            self.window_factor = 1
        else:
            raise ValueError("Data frequencies do not match frequency_array")

    def set_from_power_spectral_density(
            self, power_spectral_density, sampling_frequency, duration,
            start_time=0):
        """ Set the `frequency_domain_strain` by generating a noise realisation

        Parameters
        ==========
        power_spectral_density: bilby.gw.detector.PowerSpectralDensity
            A PowerSpectralDensity object used to generate the data
        sampling_frequency: float
            The sampling frequency (in Hz)
        duration: float
            The data duration (in s)
        start_time: float
            The GPS start-time of the data

        Raises
        ======
        ValueError
            If the noise realisation's frequencies differ from the grid.
        """

        self._set_time_and_frequency_grid(
            duration=duration, sampling_frequency=sampling_frequency,
            start_time=start_time)
        logger.debug(
            'Setting data using noise realization from provided '
            'power_spectral_density')
        frequency_domain_strain, frequency_array = \
            power_spectral_density.get_noise_realisation(
                self.sampling_frequency, self.duration)

        if np.array_equal(frequency_array, self.frequency_array):
            self._frequency_domain_strain = frequency_domain_strain
            # Drop any stale time-domain strain from a previous call.
            self._time_domain_strain = None
        else:
            raise ValueError("Data frequencies do not match frequency_array")

    def set_from_zero_noise(self, sampling_frequency, duration, start_time=0):
        """ Set the `frequency_domain_strain` to zero noise

        Parameters
        ==========
        sampling_frequency: float
            The sampling frequency (in Hz)
        duration: float
            The data duration (in s)
        start_time: float
            The GPS start-time of the data

        """

        self._set_time_and_frequency_grid(
            duration=duration, sampling_frequency=sampling_frequency,
            start_time=start_time)
        logger.debug('Setting zero noise data')
        self._frequency_domain_strain = np.zeros_like(self.frequency_array,
                                                      dtype=complex)
        # Drop any stale time-domain strain from a previous call.
        self._time_domain_strain = None

    def set_from_frame_file(
            self, frame_file, sampling_frequency, duration, start_time=0,
            channel=None, buffer_time=1):
        """ Set the strain data from a frame file

        Parameters
        ==========
        frame_file: str
            File from which to load data.
        channel: str
            Channel to read from frame.
        sampling_frequency: float
            The sampling frequency (in Hz)
        duration: float
            The data duration (in s)
        start_time: float
            The GPS start-time of the data
        buffer_time: float
            Read in data with `start_time-buffer_time` and
            `start_time+duration+buffer_time`

        """

        self._set_time_and_frequency_grid(
            duration=duration, sampling_frequency=sampling_frequency,
            start_time=start_time)

        logger.info('Reading data from frame file {}'.format(frame_file))
        strain = gwutils.read_frame_file(
            frame_file, start_time=start_time, end_time=start_time + duration,
            buffer_time=buffer_time, channel=channel,
            resample=sampling_frequency)

        self.set_from_gwpy_timeseries(strain)

    def set_from_channel_name(self, channel, duration, start_time, sampling_frequency):
        """ Set the strain data by fetching from given channel
        using gwpy.TimesSeries.get(), which dynamically accesses either frames
        on disk, or a remote NDS2 server to find and return data. This function
        also verifies that the specified channel is given in the correct format.

        Parameters
        ==========
        channel: str
            Channel to look for using gwpy in the format `IFO:Channel`
        duration: float
            The data duration (in s)
        start_time: float
            The GPS start-time of the data
        sampling_frequency: float
            The sampling frequency (in Hz)

        Raises
        ======
        IndexError
            If channel is not of the form `IFO:Channel`.
        """
        from gwpy.timeseries import TimeSeries
        channel_comp = channel.split(':')
        if len(channel_comp) != 2:
            raise IndexError('Channel name must have format `IFO:Channel`')

        self._set_time_and_frequency_grid(
            duration=duration, sampling_frequency=sampling_frequency,
            start_time=start_time)

        logger.info('Fetching data using channel {}'.format(channel))
        strain = TimeSeries.get(channel, start_time, start_time + duration)
        strain = strain.resample(sampling_frequency)

        self.set_from_gwpy_timeseries(strain)
class Notch(object):
    """ A single frequency band to exclude from the analysis. """

    def __init__(self, minimum_frequency, maximum_frequency):
        """ A notch object storing the maximum and minimum frequency of the notch

        Parameters
        ==========
        minimum_frequency, maximum_frequency: float
            The minimum and maximum frequency of the notch. They must satisfy
            0 < minimum_frequency < maximum_frequency < inf.

        Raises
        ======
        ValueError
            If the frequencies do not satisfy the ordering above.
        """

        if 0 < minimum_frequency < maximum_frequency < np.inf:
            self.minimum_frequency = minimum_frequency
            self.maximum_frequency = maximum_frequency
        else:
            msg = ("Your notch minimum_frequency {} and maximum_frequency {} are invalid"
                   .format(minimum_frequency, maximum_frequency))
            raise ValueError(msg)

    def __repr__(self):
        return "{}(minimum_frequency={}, maximum_frequency={})".format(
            type(self).__name__, self.minimum_frequency, self.maximum_frequency)

    def get_idxs(self, frequency_array):
        """ Get a boolean mask for the frequencies in frequency_array in the notch

        The notch edges are exclusive.

        Parameters
        ==========
        frequency_array: np.ndarray
            An array of frequencies

        Returns
        =======
        idxs: np.ndarray
            An array of booleans which are True for frequencies in the notch

        """
        return ((frequency_array > self.minimum_frequency) &
                (frequency_array < self.maximum_frequency))

    def check_frequency(self, freq):
        """ Check if freq is inside the notch (edges excluded)

        Parameters
        ==========
        freq: float
            The frequency to check

        Returns
        =======
        bool:
            True if freq is strictly inside the notch, else False
        """
        return bool(self.minimum_frequency < freq < self.maximum_frequency)
class NotchList(list):
    """ A list of Notch objects. """

    def __init__(self, notch_list):
        """ A list of notches

        Parameters
        ==========
        notch_list: list, None
            A list of length-2 tuples (or lists) of the (min, max) frequency
            for the notches; each pair is passed to ``Notch(min, max)``.
            None gives an empty list.

        Raises
        ======
        ValueError
            If the list is malformed (an entry is not a length-2 tuple/list),
            or if Notch rejects an entry.
        """
        super().__init__()
        if notch_list is not None:
            for notch in notch_list:
                # Lists are accepted too, e.g. pairs loaded from JSON/YAML.
                if isinstance(notch, (tuple, list)) and len(notch) == 2:
                    self.append(Notch(*notch))
                else:
                    msg = "notch_list {} is malformed".format(notch_list)
                    raise ValueError(msg)

    def check_frequency(self, freq):
        """ Check if freq is inside the notch list

        Parameters
        ==========
        freq: float
            The frequency to check

        Returns
        =======
        bool:
            True if freq is inside any of the notches, else False
        """
        return any(notch.check_frequency(freq) for notch in self)