Coverage for bilby/core/grid.py: 88%
184 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import json
2import os
4import numpy as np
6from .prior import Prior, PriorDict
7from .utils import (
8 logtrapzexp, check_directory_exists_and_if_not_mkdir, logger,
9 BilbyJsonEncoder, load_json, move_old_file
10)
11from .result import FileMovedError
def grid_file_name(outdir, label, gzip=False):
    """ Build the conventional path of a saved grid file.

    Parameters
    ==========
    outdir: str
        Directory in which the grid file lives
    label: str
        Label used as the prefix of the file name
    gzip: bool, optional
        If True, the name ends in `.json.gz` (gzipped file) instead of `.json`

    Returns
    =======
    str: ``os.path.join(outdir, '<label>_grid.json')``, with an extra `.gz`
        suffix when ``gzip`` is truthy
    """
    template = '{}_grid.json.gz' if gzip else '{}_grid.json'
    return os.path.join(outdir, template.format(label))
class Grid(object):
    """ Evaluate a prior (and optionally a likelihood) on a grid of points.

    The grid is the outer product of one set of sample points per prior
    parameter (``indexing='ij'``), so axis ``i`` of every grid-shaped array
    corresponds to ``parameter_names[i]``.
    """

    def __init__(self, likelihood=None, priors=None, grid_size=101,
                 save=False, label='no_label', outdir='.', gzip=False):
        """ Set up the grid and evaluate the prior (and likelihood) on it.

        Parameters
        ==========
        likelihood: bilby.likelihood.Likelihood
            Likelihood to evaluate at every grid point. If None, the
            likelihood is not evaluated.
        priors: bilby.prior.PriorDict
            The priors; passed through ``PriorDict``. Defaults to no priors.
        grid_size: int, list, dict
            Size of the grid, can be any of
            - int: all dimensions will have equal numbers of points
            - list: dimensions will use these points/this number of points in
            order of priors
            - dict: as for list
        save: bool
            Set whether to save the results of the grid (the file is written
            at the end of construction)
        label: str
            The label for the filename to which the grid is saved (only
            stored when ``save`` is True)
        outdir: str
            The output directory to which the grid will be saved (only
            stored when ``save`` is True)
        gzip: bool
            Set whether to gzip the output grid file
        """

        if priors is None:
            priors = dict()
        self.likelihood = likelihood
        self.priors = PriorDict(priors)
        # count the parsed PriorDict so n_dims always agrees with
        # parameter_names below
        self.n_dims = len(self.priors)
        self.parameter_names = list(self.priors.keys())

        self.sample_points = dict()
        self._get_sample_points(grid_size)
        # evaluate the prior on the grid points
        if self.n_dims > 0:
            self._ln_prior = self.priors.ln_prob(
                {key: self.mesh_grid[i].flatten() for i, key in
                 enumerate(self.parameter_names)}, axis=0).reshape(
                self.mesh_grid[0].shape)
        self._ln_likelihood = None

        # evaluate the likelihood on the grid points
        if likelihood is not None and self.n_dims > 0:
            self._evaluate()

        self.save = save
        self.label = None
        self.outdir = None
        if self.save:
            if isinstance(label, str):
                self.label = label
            if isinstance(outdir, str):
                self.outdir = os.path.abspath(outdir)
            self.save_to_file(gzip=gzip)

    @property
    def ln_prior(self):
        """ The log-prior on the grid (shape of ``mesh_grid[0]``). """
        return self._ln_prior

    @property
    def prior(self):
        """ The prior on the grid, i.e. ``exp(ln_prior)``. """
        return np.exp(self.ln_prior)

    @property
    def ln_likelihood(self):
        """ The log-likelihood on the grid, evaluated on first access. """
        if self._ln_likelihood is None:
            self._evaluate()
        return self._ln_likelihood

    @property
    def ln_posterior(self):
        """ The unnormalised log-posterior, ``ln_likelihood + ln_prior``. """
        return self.ln_likelihood + self.ln_prior

    def marginalize(self, log_array, parameters=None, not_parameters=None):
        """
        Marginalize over a list of parameters.

        Parameters
        ==========
        log_array: array_like
            A :class:`numpy.ndarray` of log likelihood/posterior values.
        parameters: list, tuple, str
            A list, or single string, of parameters to marginalize over. If None
            then all parameters will be marginalized over.
        not_parameters: list, tuple, str
            Instead of a list of parameters to marginalize over you can list
            the set of parameter to *not* marginalize over (only used when
            ``parameters`` is None).

        Returns
        =======
        out_array: array_like
            An array containing the marginalized log likelihood/posterior.

        Raises
        ======
        TypeError: If ``parameters``/``not_parameters`` is not a list, tuple
            or string.
        ValueError: If a parameter name is not a grid parameter.
        """

        if parameters is None:
            params = list(self.parameter_names)

            if not_parameters is not None:
                if isinstance(not_parameters, str):
                    not_params = [not_parameters]
                elif isinstance(not_parameters, (list, tuple)):
                    not_params = list(not_parameters)
                else:
                    raise TypeError("Parameters names must be a list or string")

                params = [name for name in params if name not in not_params]
        elif isinstance(parameters, str):
            params = [parameters]
        elif isinstance(parameters, (list, tuple)):
            params = list(parameters)
        else:
            raise TypeError("Parameters names must be a list or string")

        out_array = log_array.copy()
        # names of the axes still present in out_array; _marginalize_single
        # removes each name from this list as it removes the axis
        names = list(self.parameter_names)

        for name in params:
            out_array = self._marginalize_single(out_array, name, names)

        return out_array

    def _marginalize_single(self, log_array, name, non_marg_names=None):
        """
        Marginalize the log likelihood/posterior over a single given parameter.

        Parameters
        ==========
        log_array: array_like
            A :class:`numpy.ndarray` of log likelihood/posterior values.
        name: str
            The name of the parameter to marginalize over.
        non_marg_names: list
            A list of parameter names that have not been marginalized over,
            in axis order. It is modified in place: ``name`` is removed.

        Returns
        =======
        out: array_like
            An array containing the marginalized log likelihood/posterior.
        """

        if name not in self.parameter_names:
            raise ValueError("'{}' is not a recognised "
                             "parameter".format(name))

        if non_marg_names is None:
            non_marg_names = list(self.parameter_names)

        axis = non_marg_names.index(name)
        non_marg_names.remove(name)

        places = self.sample_points[name]

        if len(places) > 1:
            dx = np.diff(places)
            out = np.apply_along_axis(
                logtrapzexp, axis, log_array, dx
            )
        else:
            # no marginalisation required, just remove the singleton dimension
            out = np.squeeze(log_array, axis=axis)

        return out

    @property
    def ln_evidence(self):
        """ The log-posterior marginalised over every grid parameter. """
        return self.marginalize(self.ln_posterior)

    @property
    def log_evidence(self):
        """ Alias of ``ln_evidence``. """
        return self.ln_evidence

    @property
    def log_noise_evidence(self):
        """ Alias of ``ln_noise_evidence`` (set by ``_evaluate`` or ``read``). """
        return self.ln_noise_evidence

    def marginalize_ln_likelihood(self, parameters=None, not_parameters=None):
        """
        Marginalize the ln likelihood over either the specified parameter or
        all but the specified "not_parameter". If neither is specified the
        ln likelihood will be fully marginalized over.

        Parameters
        ==========
        parameters: str, list, optional
            Name of, or list of names of, the parameter(s) to marginalize over.
        not_parameters: str, list, optional
            Name of, or list of names of, the parameter(s) to not marginalize over.

        Returns
        =======
        array-like:
            The marginalized ln likelihood.
        """
        return self.marginalize(self.ln_likelihood, parameters=parameters,
                                not_parameters=not_parameters)

    def marginalize_ln_posterior(self, parameters=None, not_parameters=None):
        """
        Marginalize the ln posterior over either the specified parameter or all
        but the specified "not_parameter". If neither is specified the
        ln posterior will be fully marginalized over.

        Parameters
        ==========
        parameters: str, list, optional
            Name of, or list of names of, the parameter(s) to marginalize over.
        not_parameters: str, list, optional
            Name of, or list of names of, the parameter(s) to not marginalize over.

        Returns
        =======
        array-like:
            The marginalized ln posterior.
        """
        return self.marginalize(self.ln_posterior, parameters=parameters,
                                not_parameters=not_parameters)

    def marginalize_likelihood(self, parameters=None, not_parameters=None):
        """
        Marginalize the likelihood over either the specified parameter or all
        but the specified "not_parameter". If neither is specified the
        likelihood will be fully marginalized over.

        Parameters
        ==========
        parameters: str, list, optional
            Name of, or list of names of, the parameter(s) to marginalize over.
        not_parameters: str, list, optional
            Name of, or list of names of, the parameter(s) to not marginalize over.

        Returns
        =======
        array-like:
            The marginalized likelihood, scaled so its maximum is 1.
        """
        ln_like = self.marginalize(self.ln_likelihood, parameters=parameters,
                                   not_parameters=not_parameters)
        # NOTE: the output will not be properly normalised
        return np.exp(ln_like - np.max(ln_like))

    def marginalize_posterior(self, parameters=None, not_parameters=None):
        """
        Marginalize the posterior over either the specified parameter or all
        but the specified "not_parameters". If neither is specified the
        posterior will be fully marginalized over.

        Parameters
        ==========
        parameters: str, list, optional
            Name of, or list of names of, the parameter(s) to marginalize over.
        not_parameters: str, list, optional
            Name of, or list of names of, the parameter(s) to not marginalize over.

        Returns
        =======
        array-like:
            The marginalized posterior, scaled so its maximum is 1.
        """
        ln_post = self.marginalize(self.ln_posterior, parameters=parameters,
                                   not_parameters=not_parameters)
        # NOTE: the output will not be properly normalised
        return np.exp(ln_post - np.max(ln_post))

    def _evaluate(self):
        """ Evaluate the log-likelihood at every grid point.

        Fills ``_ln_likelihood`` and sets ``ln_noise_evidence``.

        Raises
        ======
        ValueError: If the grid has no likelihood.
        """
        if self.likelihood is None:
            raise ValueError(
                "No likelihood has been given, so the log likelihood cannot "
                "be evaluated on the grid")
        self._ln_likelihood = np.empty(self.mesh_grid[0].shape)
        try:
            self._evaluate_recursion(0)
        except Exception:
            # do not leave a partially filled (uninitialised) array cached;
            # ln_likelihood would otherwise return it as if it were valid
            self._ln_likelihood = None
            raise
        self.ln_noise_evidence = self.likelihood.noise_log_likelihood()

    def _evaluate_recursion(self, dimension, index=()):
        """ Set parameters one dimension at a time; evaluate at the leaves.

        Parameters
        ==========
        dimension: int
            Index (into ``parameter_names``) of the next parameter to set.
        index: tuple of int
            Grid indices already chosen for the preceding dimensions.
        """
        if dimension == self.n_dims:
            # the grid index is carried down explicitly rather than recovered
            # by searching sample_points for the parameter value, which fails
            # for repeated sample points and needs int() of a 1-element array
            self._ln_likelihood[index] = self.likelihood.log_likelihood()
        else:
            name = self.parameter_names[dimension]
            for ii in range(self._ln_likelihood.shape[dimension]):
                self.likelihood.parameters[name] = self.sample_points[name][ii]
                self._evaluate_recursion(dimension + 1, index + (ii,))

    def _get_sample_points(self, grid_size):
        """ Fill ``sample_points`` for each prior and build ``mesh_grid``.

        An integer size ``n`` gives ``n`` points from ``np.linspace(0, 1, n)``
        mapped through the prior's ``rescale``; any other per-parameter value
        is used directly as the sample points.
        """
        for ii, key in enumerate(self.parameter_names):
            if isinstance(self.priors[key], Prior):
                if isinstance(grid_size, (int, np.integer)):
                    size = grid_size
                elif isinstance(grid_size, list):
                    size = grid_size[ii]
                elif isinstance(grid_size, dict):
                    size = grid_size[key]
                else:
                    raise TypeError("Unrecognized 'grid_size' type")

                if isinstance(size, (int, np.integer)):
                    self.sample_points[key] = self.priors[key].rescale(
                        np.linspace(0, 1, size))
                else:
                    self.sample_points[key] = size

        # set the mesh of points
        self.mesh_grid = np.meshgrid(
            *(self.sample_points[key] for key in self.parameter_names),
            indexing='ij')

    def _get_save_data_dictionary(self):
        """ Collect the attributes written by ``save_to_file``.

        Attributes that cannot be computed (e.g. no likelihood was given, so
        there is no ln_likelihood/ln_noise_evidence) are skipped with a
        debug log message.
        """
        # This list defines all the parameters saved in the grid object
        save_attrs = [
            'label', 'outdir', 'parameter_names', 'n_dims', 'priors',
            'sample_points', 'ln_likelihood', 'ln_evidence',
            'ln_noise_evidence']
        dictionary = dict()
        for attr in save_attrs:
            try:
                dictionary[attr] = getattr(self, attr)
            except (ValueError, AttributeError) as e:
                logger.debug("Unable to save {}, message: {}".format(attr, e))
        return dictionary

    def _safe_outdir_creation(self, outdir=None, caller_func=None):
        """ Ensure the output directory exists and return it.

        Falls back to ``self.outdir`` when ``outdir`` is None. A
        PermissionError is re-raised as FileMovedError naming ``caller_func``.
        """
        if outdir is None:
            outdir = self.outdir
        try:
            check_directory_exists_and_if_not_mkdir(outdir)
        except PermissionError as err:
            raise FileMovedError("Can not write in the out directory.\n"
                                 "Did you move the here file from another system?\n"
                                 "Try calling " + caller_func.__name__ + " with the 'outdir' "
                                 "keyword argument, e.g. " + caller_func.__name__ + "(outdir='.')") from err
        return outdir

    def save_to_file(self, filename=None, overwrite=False, outdir=None,
                     gzip=False):
        """
        Writes the Grid to a file.

        Failures while serialising/writing are logged as errors rather than
        raised.

        Parameters
        ==========
        filename: str, optional
            Filename to write to (overwrites the default)
        overwrite: bool, optional
            Whether or not to overwrite an existing result file.
            default=False
        outdir: str, optional
            Path to the outdir. Default is the one stored in the Grid object.
        gzip: bool, optional
            If true this will gzip the resulting file and add '.gz' to the file
            extension.

        Raises
        ======
        ValueError: If no filename is given and the grid has no label.
        """

        outdir = self._safe_outdir_creation(outdir, self.save_to_file)
        if filename is None:
            if self.label is None:
                raise ValueError("'label' for the output file name is not given")

            filename = grid_file_name(outdir, self.label, gzip)

        move_old_file(filename, overwrite)
        dictionary = self._get_save_data_dictionary()

        try:
            dictionary["priors"] = dictionary["priors"]._get_json_dict()
            if gzip or (os.path.splitext(filename)[-1] == '.gz'):
                # imported under another name so it does not shadow the
                # ``gzip`` argument
                import gzip as gzip_module
                # encode to a string
                json_str = json.dumps(dictionary, cls=BilbyJsonEncoder).encode('utf-8')
                with gzip_module.GzipFile(filename, 'w') as out_file:
                    out_file.write(json_str)
            else:
                with open(filename, 'w') as out_file:
                    json.dump(dictionary, out_file, indent=2, cls=BilbyJsonEncoder)
        except Exception as e:
            logger.error("\n\n Saving the data has failed with the "
                         "following message:\n {} \n\n".format(e))

    @classmethod
    def read(cls, filename=None, outdir=None, label=None, gzip=False):
        """ Read in a saved .json grid file

        Parameters
        ==========
        filename: str
            If given, try to load from this filename
        outdir, label: str
            If given, use the default naming convention for saved results file
        gzip: bool
            If given, whether the file is gzipped or not (only required if the
            file is gzipped, but does not have the standard '.gz' file
            extension)

        Returns
        =======
        grid: bilby.core.grid.Grid

        Raises
        ======
        ValueError: If no filename is given and either outdir or label is None
        IOError: If the file does not exist, or its contents cannot be loaded
            as a bilby.core.grid.Grid
        """

        if filename is None:
            if outdir is None or label is None:
                raise ValueError("No information given to load file")
            filename = grid_file_name(outdir, label, gzip)

        if not os.path.isfile(filename):
            raise IOError("No result '{}' found".format(filename))

        dictionary = load_json(filename, gzip)
        try:
            grid = cls(likelihood=None, priors=dictionary['priors'],
                       grid_size=dictionary['sample_points'],
                       label=dictionary['label'], outdir=dictionary['outdir'])

            # __init__ only stores label/outdir when save=True, so restore them
            grid.label = dictionary['label']
            grid.outdir = dictionary['outdir']

            # set the likelihood
            grid._ln_likelihood = dictionary['ln_likelihood']
            grid.ln_noise_evidence = dictionary['ln_noise_evidence']
        except (TypeError, KeyError) as e:
            raise IOError("Unable to load dictionary, error={}".format(e)) from e

        return grid