Coverage for parallel_bilby/analysis/likelihood.py: 64%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import inspect
2from importlib import import_module
4import bilby
5import bilby_pipe
6import numpy as np
7from bilby.core.utils import logger
def reorder_loglikelihoods(unsorted_loglikelihoods, unsorted_samples, sorted_samples):
    """Reorder the stored log-likelihoods after the samples have been reweighted.

    A sorting index is built by matching each row of ``sorted_samples``
    (e.g. the reweighted ``result.samples``) against the rows of the raw
    ``unsorted_samples``. That index is then used to reorder the
    log-likelihoods.

    Parameters
    ----------
    unsorted_loglikelihoods: array-like
        The log-likelihoods corresponding to ``unsorted_samples``.
    unsorted_samples, sorted_samples: array-like
        Unsorted and sorted values of the samples, one sample per row. A 1-D
        input is treated as one scalar per sample. The two should have the
        same shape and contain the same sample values, in different orders.

    Returns
    -------
    sorted_loglikelihoods: numpy.ndarray
        The log-likelihoods reordered to match ``sorted_samples``.

    Raises
    ------
    ValueError
        If a sorted sample has no exact match among the unsorted samples.
    """
    # Accept any array-like (e.g. nested lists), as the docstring promises.
    unsorted_loglikelihoods = np.asarray(unsorted_loglikelihoods)
    unsorted_samples = np.asarray(unsorted_samples)
    sorted_samples = np.asarray(sorted_samples)

    # Index every unsorted row once, so each lookup below is O(1) instead of
    # a full scan of ``unsorted_samples``. Positions are kept in ascending
    # order, so positions[0] is the first match, as before.
    positions = {}
    for jj, row in enumerate(unsorted_samples):
        positions.setdefault(tuple(np.atleast_1d(row)), []).append(jj)

    idxs = []
    n_ambiguous = 0
    for ii in range(len(unsorted_loglikelihoods)):
        matches = positions.get(tuple(np.atleast_1d(sorted_samples[ii])))
        if not matches:
            raise ValueError(
                f"Sorted sample {ii} has no exact match among the unsorted samples"
            )
        if len(matches) > 1:
            n_ambiguous += 1
        idxs.append(matches[0])

    if n_ambiguous:
        logger.warning(
            "Multiple likelihood matches found between sorted and "
            f"unsorted samples ({n_ambiguous} sample(s)). Taking the first match."
        )
    return unsorted_loglikelihoods[idxs]
def roq_likelihood_kwargs(args):
    """Collect the keyword arguments needed to build an ROQ likelihood.

    The ROQ inputs come from the first source found, checked in this order:

    1. ``args.likelihood_roq_params`` and ``args.likelihood_roq_weights``.
       This source is used whenever both attributes exist, even if they are
       ``None``.
    2. ``args.roq_folder``, when it is not ``None``. The parameters are read
       from ``params.dat`` in that folder, and ``args.weight_file`` is used
       as the weights.
    3. ``args.roq_linear_matrix``, when it is not ``None``, together with
       ``args.roq_quadratic_matrix``.

    Entries not provided by the chosen source are left as ``None``.

    Parameters
    ----------
    args: Namespace
        The parser arguments. ``args.roq_scale_factor`` is always read.

    Returns
    -------
    kwargs: dict
        Keys ``weights``, ``roq_params``, ``linear_matrix``,
        ``quadratic_matrix`` and ``roq_scale_factor``.
    """
    roq_kwargs = {
        "weights": None,
        "roq_params": None,
        "linear_matrix": None,
        "quadratic_matrix": None,
        "roq_scale_factor": args.roq_scale_factor,
    }

    # Source 1: params/weights already attached to args.
    preloaded = all(
        hasattr(args, attr_name)
        for attr_name in ("likelihood_roq_params", "likelihood_roq_weights")
    )
    if preloaded:
        roq_kwargs["roq_params"] = args.likelihood_roq_params
        roq_kwargs["weights"] = args.likelihood_roq_weights
        return roq_kwargs

    # Source 2: an ROQ basis folder on disk.
    roq_folder = getattr(args, "roq_folder", None)
    if roq_folder is not None:
        logger.info(f"Loading ROQ weights from {args.roq_folder}, {args.weight_file}")
        roq_kwargs["roq_params"] = np.genfromtxt(roq_folder + "/params.dat", names=True)
        roq_kwargs["weights"] = args.weight_file
        return roq_kwargs

    # Source 3: explicit linear/quadratic basis matrices.
    linear_matrix = getattr(args, "roq_linear_matrix", None)
    if linear_matrix is not None:
        logger.info(f"Loading linear_matrix from {args.roq_linear_matrix}")
        logger.info(f"Loading quadratic_matrix from {args.roq_quadratic_matrix}")
        roq_kwargs["linear_matrix"] = linear_matrix
        roq_kwargs["quadratic_matrix"] = args.roq_quadratic_matrix

    return roq_kwargs
def setup_likelihood(interferometers, waveform_generator, priors, args):
    """Set up and return either an ROQ GW likelihood or a GW likelihood from the parser arguments.

    Parameters
    ----------
    interferometers: bilby.gw.detectors.InterferometerList
        The pre-loaded bilby IFO
    waveform_generator: bilby.gw.waveform_generator.LALCBCWaveformGenerator
        The waveform generation
    priors: dict
        The priors, used for setting up marginalization
    args: Namespace
        The parser arguments

    Returns
    -------
    likelihood: bilby.gw.likelihood.GravitationalWaveTransient
        The likelihood. This is GravitationalWaveTransient,
        ROQGravitationalWaveTransient, or the class named by a dotted
        ``module.ClassName`` path in ``args.likelihood_type``.

    Raises
    ------
    ValueError
        If ``args.likelihood_type`` is neither a known likelihood name nor a
        dotted import path.
    """
    likelihood_kwargs = dict(
        interferometers=interferometers,
        waveform_generator=waveform_generator,
        priors=priors,
        phase_marginalization=args.phase_marginalization,
        distance_marginalization=args.distance_marginalization,
        distance_marginalization_lookup_table=args.distance_marginalization_lookup_table,
        time_marginalization=args.time_marginalization,
        reference_frame=args.reference_frame,
        time_reference=args.time_reference,
    )

    if args.likelihood_type == "GravitationalWaveTransient":
        Likelihood = bilby.gw.likelihood.GravitationalWaveTransient
        likelihood_kwargs.update(jitter_time=args.jitter_time)

    elif args.likelihood_type == "ROQGravitationalWaveTransient":
        Likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient

        if args.time_marginalization:
            logger.warning(
                "Time marginalization not implemented for "
                "ROQGravitationalWaveTransient: option ignored"
            )

        # jitter_time is only added in the GravitationalWaveTransient branch,
        # so only time_marginalization needs removing here.
        likelihood_kwargs.pop("time_marginalization", None)
        likelihood_kwargs.update(roq_likelihood_kwargs(args))

    elif "." in args.likelihood_type:
        # Dotted path such as "package.module.LikelihoodClass": import the
        # module part and look up the class named after the last dot.
        module, _, likelihood_class = args.likelihood_type.rpartition(".")
        Likelihood = getattr(import_module(module), likelihood_class)
        likelihood_kwargs.update(
            bilby_pipe.utils.convert_string_to_dict(args.extra_likelihood_kwargs)
        )
        if "roq" in args.likelihood_type.lower():
            # The extra kwargs may have reintroduced these options.
            likelihood_kwargs.pop("time_marginalization", None)
            likelihood_kwargs.pop("jitter_time", None)
            # NOTE(review): this previously read ``args.roq_likelihood_kwargs``
            # unconditionally, which may not be set on ``args``. Use it when
            # present; otherwise build the kwargs the same way the ROQ branch
            # above does.
            roq_kwargs = getattr(args, "roq_likelihood_kwargs", None)
            if roq_kwargs is None:
                roq_kwargs = roq_likelihood_kwargs(args)
            likelihood_kwargs.update(roq_kwargs)

    else:
        raise ValueError(f"Unknown Likelihood class {args.likelihood_type}")

    # Forward only kwargs that Likelihood.__init__ names explicitly, whether
    # positional-or-keyword or keyword-only. Anything else is dropped rather
    # than triggering a TypeError, and the dropped names are logged.
    spec = inspect.getfullargspec(Likelihood.__init__)
    accepted = set(spec.args) | set(spec.kwonlyargs)
    dropped = sorted(key for key in likelihood_kwargs if key not in accepted)
    likelihood_kwargs = {
        key: value for key, value in likelihood_kwargs.items() if key in accepted
    }
    if dropped:
        logger.info(f"Ignoring kwargs not accepted by {Likelihood}: {dropped}")

    logger.info(
        f"Initialise likelihood {Likelihood} with kwargs: \n{likelihood_kwargs}"
    )

    likelihood = Likelihood(**likelihood_kwargs)
    return likelihood