Coverage for parallel_bilby/analysis/analysis_run.py: 81%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import logging
2import os
3import pickle
5import bilby
6import dynesty
7import numpy as np
8from bilby.core.sampler.base_sampler import _SamplingContainer
9from bilby.core.sampler.dynesty import DynestySetupError, _set_sampling_kwargs
10from bilby.core.sampler.dynesty_utils import (
11 AcceptanceTrackingRWalk,
12 ACTTrackingRWalk,
13 FixedRWalk,
14 LivePointSampler,
15 MultiEllipsoidLivePointSampler,
16)
17from bilby.core.utils import logger
18from bilby_pipe.utils import convert_string_to_list
20from .likelihood import setup_likelihood
class AnalysisRun(object):
    """
    An object with methods for driving the sampling run.

    Reads a pickled data dump produced by the generation stage, builds the
    likelihood and priors, and configures the dynesty sampler settings
    (sample method, bound method, random state) used by the analysis nodes.

    Parameters: arguments to set the output path and control the dynesty sampler.
    """

    def __init__(
        self,
        data_dump,
        outdir=None,
        label=None,
        dynesty_sample="acceptance-walk",
        nlive=5,
        dynesty_bound="live",
        walks=100,
        maxmcmc=5000,
        naccept=60,
        nact=2,
        facc=0.5,
        min_eff=10,
        enlarge=1.5,
        sampling_seed=0,
        proposals=None,
        bilby_zero_likelihood_mode=False,
    ):
        self.maxmcmc = maxmcmc
        self.nact = nact
        self.naccept = naccept
        # proposals may arrive as a comma-separated string from the CLI
        self.proposals = convert_string_to_list(proposals)

        # Read data dump from the pickle file
        with open(data_dump, "rb") as file:
            data_dump = pickle.load(file)

        ifo_list = data_dump["ifo_list"]
        waveform_generator = data_dump["waveform_generator"]
        # Align the waveform generator with the start of the analysis segment
        waveform_generator.start_time = ifo_list[0].time_array[0]
        args = data_dump["args"]
        injection_parameters = data_dump.get("injection_parameters", None)

        args.weight_file = data_dump["meta_data"].get("weight_file", None)

        # If the run dir has not been specified, get it from the args
        if outdir is None:
            outdir = args.outdir
        else:
            # Create the run dir
            os.makedirs(outdir, exist_ok=True)

        # If the label has not been specified, get it from the args
        if label is None:
            label = args.label

        priors = bilby.gw.prior.PriorDict.from_json(data_dump["prior_file"])

        # Temporarily raise the log level so likelihood setup is quiet
        logger.setLevel(logging.WARNING)
        likelihood = setup_likelihood(
            interferometers=ifo_list,
            waveform_generator=waveform_generator,
            priors=priors,
            args=args,
        )
        priors.convert_floats_to_delta_functions()
        logger.setLevel(logging.INFO)

        # Partition the prior keys: constraints are skipped, fixed parameters
        # are pinned in the likelihood, the rest are sampled
        sampling_keys = []
        for p in priors:
            if isinstance(priors[p], bilby.core.prior.Constraint):
                continue
            elif priors[p].is_fixed:
                likelihood.parameters[p] = priors[p].peak
            else:
                sampling_keys.append(p)

        # Record the indices of parameters with periodic/reflective boundaries
        # (dynesty expects index lists, or None when there are none)
        periodic = []
        reflective = []
        for ii, key in enumerate(sampling_keys):
            if priors[key].boundary == "periodic":
                logger.debug(f"Setting periodic boundary for {key}")
                periodic.append(ii)
            elif priors[key].boundary == "reflective":
                logger.debug(f"Setting reflective boundary for {key}")
                reflective.append(ii)

        if len(periodic) == 0:
            periodic = None
        if len(reflective) == 0:
            reflective = None

        # Keyword arguments forwarded verbatim to dynesty.NestedSampler
        self.init_sampler_kwargs = dict(
            nlive=nlive,
            sample=dynesty_sample,
            bound=dynesty_bound,
            walks=walks,
            facc=facc,
            first_update=dict(min_eff=min_eff, min_ncall=2 * nlive),
            enlarge=enlarge,
        )

        self._set_sampling_method()

        # Create a random generator, which is saved across restarts
        # This ensures that runs are fully deterministic, which is important
        # for reproducibility
        self.sampling_seed = sampling_seed
        self.rstate = np.random.Generator(np.random.PCG64(self.sampling_seed))
        logger.debug(
            f"Setting random state = {self.rstate} (seed={self.sampling_seed})"
        )

        self.outdir = outdir
        self.label = label
        self.data_dump = data_dump
        self.priors = priors
        self.sampling_keys = sampling_keys
        self.likelihood = likelihood
        self.zero_likelihood_mode = bilby_zero_likelihood_mode
        self.periodic = periodic
        self.reflective = reflective
        self.args = args
        self.injection_parameters = injection_parameters
        self.nlive = nlive

    def _set_sampling_method(self):
        """
        Validate the (sample, bound) combination and register the matching
        bilby sampler implementations with dynesty's module-level registries.

        Mutates ``self.init_sampler_kwargs`` (may rewrite ``bound``/``sample``),
        ``self.proposals``, and the global ``dynesty`` registries.

        Raises
        ------
        DynestySetupError
            If the requested combination of sample/bound/proposal settings
            is inconsistent.
        """
        sample = self.init_sampler_kwargs["sample"]
        bound = self.init_sampler_kwargs["bound"]

        # Push the MCMC tuning parameters into bilby's shared container
        _set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept))

        if sample not in ["rwalk", "act-walk", "acceptance-walk"] and bound in [
            "live",
            "live-multi",
        ]:
            logger.info(
                "Live-point based bound method requested with dynesty sample "
                f"'{sample}', overwriting to 'multi'"
            )
            self.init_sampler_kwargs["bound"] = "multi"
        elif bound == "live":
            dynesty.dynamicsampler._SAMPLERS["live"] = LivePointSampler
        elif bound == "live-multi":
            dynesty.dynamicsampler._SAMPLERS[
                "live-multi"
            ] = MultiEllipsoidLivePointSampler
        elif sample == "acceptance-walk":
            raise DynestySetupError(
                "bound must be set to live or live-multi for sample=acceptance-walk"
            )
        elif self.proposals is None:
            logger.warning(
                "No proposals specified using dynesty sampling, defaulting "
                "to 'volumetric'."
            )
            self.proposals = ["volumetric"]
            _SamplingContainer.proposals = self.proposals
        elif "diff" in self.proposals:
            raise DynestySetupError(
                "bound must be set to live or live-multi to use differential "
                "evolution proposals"
            )

        if sample == "rwalk":
            logger.info(
                "Using the bilby-implemented rwalk sample method with ACT estimated walks. "
                f"An average of {2 * self.nact} steps will be accepted up to chain length "
                f"{self.maxmcmc}."
            )
            if self.init_sampler_kwargs["walks"] > self.maxmcmc:
                raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)")
            if self.nact < 1:
                raise DynestySetupError("Unable to run with nact < 1")
            dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk()
        elif sample == "acceptance-walk":
            logger.info(
                "Using the bilby-implemented rwalk sampling with an average of "
                f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}"
            )
            dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk()
        elif sample == "act-walk":
            logger.info(
                "Using the bilby-implemented rwalk sampling tracking the "
                f"autocorrelation function and thinning by "
                f"{self.nact} with maximum length {self.nact * self.maxmcmc}"
            )
            dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk()
        elif sample == "rwalk_dynesty":
            # BUGFIX: str.strip("_dynesty") removes a *character set*, not a
            # suffix, and only produced "rwalk" by coincidence. Remove the
            # suffix explicitly instead.
            sample = sample.replace("_dynesty", "")
            self.init_sampler_kwargs["sample"] = sample
            logger.info(f"Using the dynesty-implemented {sample} sample method")

    def prior_transform_function(self, u_array):
        """
        Calls the bilby rescaling function on an array of values

        Parameters
        ----------
        u_array: (float, array-like)
            The values to rescale

        Returns
        -------
        (float, array-like)
            The rescaled values

        """
        return self.priors.rescale(self.sampling_keys, u_array)

    def log_likelihood_function(self, v_array):
        """
        Calculates the log(likelihood)

        Parameters
        ----------
        v_array: (float, array-like)
            Parameter values in the order of ``self.sampling_keys``

        Returns
        -------
        float
            The log-likelihood minus the noise log-likelihood, or 0 in
            zero-likelihood mode, or a large negative number if the
            prior constraints are violated

        """
        # In zero-likelihood mode we skip the (expensive) likelihood entirely
        if self.zero_likelihood_mode:
            return 0
        parameters = {key: v for key, v in zip(self.sampling_keys, v_array)}
        if self.priors.evaluate_constraints(parameters) > 0:
            self.likelihood.parameters.update(parameters)
            # Return the log-likelihood *ratio* relative to the noise model
            return (
                self.likelihood.log_likelihood()
                - self.likelihood.noise_log_likelihood()
            )
        else:
            # nan_to_num converts -inf to a large finite negative number,
            # which dynesty handles more gracefully than -inf
            return np.nan_to_num(-np.inf)

    def log_prior_function(self, v_array):
        """
        Calculates the log of the prior

        Parameters
        ----------
        v_array: (float, array-like)
            The prior values

        Returns
        -------
        (float, array-like)
            The log probability of the values

        """
        params = {key: t for key, t in zip(self.sampling_keys, v_array)}
        return self.priors.ln_prob(params)

    def get_initial_points_from_prior(self, pool, calculate_likelihood=True):
        """
        Generates a set of initial points, drawn from the prior

        Parameters
        ----------
        pool: schwimmbad.MPIPool
            Schwimmbad pool for MPI parallelisation
            (pbilby implements a modified version: MPIPoolFast)

        calculate_likelihood: bool
            Option to calculate the likelihood for the generated points
            (default: True)

        Returns
        -------
        (numpy.ndarray, numpy.ndarray, numpy.ndarray, None)
            Returns a tuple (unit, theta, logl, blob)
            unit: point in the unit cube
            theta: scaled value
            logl: log(likelihood)
            blob: None

        """
        # Create a new rstate for each point, otherwise each task will generate
        # the same random number, and the rstate on master will not be incremented.
        # The argument to self.rstate.integers() is a very large integer.
        # These rstates aren't used after this map, but each time they are created,
        # a different (but deterministic) seed is used.
        sg = np.random.SeedSequence(self.rstate.integers(9223372036854775807))
        map_rstates = [
            np.random.Generator(np.random.PCG64(n)) for n in sg.spawn(self.nlive)
        ]
        ndim = len(self.sampling_keys)

        args_list = [
            (
                self.prior_transform_function,
                self.log_prior_function,
                self.log_likelihood_function,
                ndim,
                calculate_likelihood,
                map_rstates[i],
            )
            for i in range(self.nlive)
        ]
        initial_points = pool.map(self.get_initial_point_from_prior, args_list)
        u_list = [point[0] for point in initial_points]
        v_list = [point[1] for point in initial_points]
        l_list = [point[2] for point in initial_points]
        blobs = None

        return np.array(u_list), np.array(v_list), np.array(l_list), blobs

    @staticmethod
    def get_initial_point_from_prior(args):
        """
        Draw initial points from the prior subject to constraints applied both to
        the prior and the likelihood.

        We remove any points where the likelihood or prior is infinite or NaN.

        The `log_likelihood_function` often converts infinite values to large
        finite values so we catch those.
        """
        (
            prior_transform_function,
            log_prior_function,
            log_likelihood_function,
            ndim,
            calculate_likelihood,
            rstate,
        ) = args
        # np.nan_to_num(np.inf) is the large finite sentinel produced by
        # log_likelihood_function for constraint-violating points
        bad_values = [np.inf, np.nan_to_num(np.inf), np.nan]
        while True:
            unit = rstate.random(ndim)
            theta = prior_transform_function(unit)

            if abs(log_prior_function(theta)) not in bad_values:
                if calculate_likelihood:
                    logl = log_likelihood_function(theta)
                    if abs(logl) not in bad_values:
                        return unit, theta, logl
                else:
                    return unit, theta, np.nan

    def get_nested_sampler(self, live_points, pool, pool_size):
        """
        Returns the dynesty nested sampler, getting most arguments
        from the object's attributes

        Parameters
        ----------
        live_points: (numpy.ndarray, numpy.ndarray, numpy.ndarray)
            The set of live points, in the same format as returned by
            get_initial_points_from_prior

        pool: schwimmbad.MPIPool
            Schwimmbad pool for MPI parallelisation
            (pbilby implements a modified version: MPIPoolFast)

        pool_size: int
            Number of workers in the pool

        Returns
        -------
        dynesty.NestedSampler

        """
        ndim = len(self.sampling_keys)
        sampler = dynesty.NestedSampler(
            self.log_likelihood_function,
            self.prior_transform_function,
            ndim,
            pool=pool,
            queue_size=pool_size,
            periodic=self.periodic,
            reflective=self.reflective,
            live_points=live_points,
            rstate=self.rstate,
            # Delegate all parallelisable stages to the MPI pool
            use_pool=dict(
                update_bound=True,
                propose_point=True,
                prior_transform=True,
                loglikelihood=True,
            ),
            **self.init_sampler_kwargs,
        )

        return sampler