Coverage for bilby/core/utils/io.py: 70%
238 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import datetime
2import inspect
3import json
4import os
5import shutil
6from importlib import import_module
7from pathlib import Path
8from datetime import timedelta
10import numpy as np
11import pandas as pd
13from .log import logger
14from .introspection import infer_args_from_method
def check_directory_exists_and_if_not_mkdir(directory):
    """ Ensure that a directory exists, creating it if necessary

    Any missing parent directories are created as well, and an already
    existing directory is left untouched.

    Parameters
    ==========
    directory: str
        Name of the directory
    """
    target = Path(directory)
    target.mkdir(exist_ok=True, parents=True)
class BilbyJsonEncoder(json.JSONEncoder):
    """JSON encoder for the numpy, pandas, astropy and bilby types used by bilby.

    Objects the standard encoder cannot handle are converted into tagged
    dictionaries (e.g. ``{"__array__": True, "content": [...]}``), most of
    which ``decode_bilby_json`` recognises when loading. ``ProposalCycle``
    objects and astropy ``PrefixUnit`` objects are only stored as strings.
    """

    def default(self, obj):
        """Return a JSON-serialisable representation of ``obj``.

        Objects not matched by any branch are passed to
        ``json.JSONEncoder.default``, which raises ``TypeError``.
        """
        from ..prior import MultivariateGaussianDist, Prior, PriorDict
        from ...gw.prior import HealPixMapPriorDist
        from ...bilby_mcmc.proposals import ProposalCycle

        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, PriorDict):
            return {"__prior_dict__": True, "content": obj._get_json_dict()}
        if isinstance(obj, (MultivariateGaussianDist, HealPixMapPriorDist, Prior)):
            return {
                "__prior__": True,
                "__module__": obj.__module__,
                "__name__": obj.__class__.__name__,
                "kwargs": dict(obj.get_instantiation_dict()),
            }
        if isinstance(obj, ProposalCycle):
            return str(obj)
        try:
            # astropy is optional; its types are only handled when installed.
            from astropy import cosmology as cosmo, units

            if isinstance(obj, cosmo.FLRW):
                return encode_astropy_cosmology(obj)
            if isinstance(obj, units.Quantity):
                return encode_astropy_quantity(obj)
            if isinstance(obj, units.PrefixUnit):
                return str(obj)
        except ImportError:
            logger.debug("Cannot import astropy, cannot write cosmological priors")
        if isinstance(obj, np.ndarray):
            return {"__array__": True, "content": obj.tolist()}
        if isinstance(obj, complex):
            return {"__complex__": True, "real": obj.real, "imag": obj.imag}
        if isinstance(obj, pd.DataFrame):
            return {"__dataframe__": True, "content": obj.to_dict(orient="list")}
        if isinstance(obj, pd.Series):
            return {"__series__": True, "content": obj.to_dict()}
        if inspect.isfunction(obj):
            return {
                "__function__": True,
                "__module__": obj.__module__,
                "__name__": obj.__name__,
            }
        if inspect.isclass(obj):
            return {
                "__class__": True,
                "__module__": obj.__module__,
                "__name__": obj.__name__,
            }
        if isinstance(obj, timedelta):
            return {
                "__timedelta__": True,
                "__total_seconds__": obj.total_seconds(),
            }
        return json.JSONEncoder.default(self, obj)
def encode_astropy_cosmology(obj):
    """Encode an astropy cosmology object as a tagged dictionary.

    Every argument that ``infer_args_from_method`` reports for the object's
    ``__init__`` is read back from the instance as an attribute. The result
    also carries ``__cosmology__: True`` and the class name under
    ``__name__``, which ``decode_astropy_cosmology`` uses to rebuild it.
    """
    class_name = obj.__class__.__name__
    encoded = dict()
    for argument in infer_args_from_method(obj.__init__):
        encoded[argument] = getattr(obj, argument)
    encoded.update(__cosmology__=True, __name__=class_name)
    return encoded
def encode_astropy_quantity(dct):
    """Encode an astropy ``Quantity`` as a JSON-serialisable tagged dictionary.

    Parameters
    ==========
    dct: astropy.units.Quantity
        The quantity to encode. Despite its name (kept for backward
        compatibility), this is the quantity object, not a dictionary.

    Returns
    =======
    dict
        ``{"__astropy_quantity__": True, "value": value, "unit": str(unit)}``.
        Array values are converted with ``ndarray.tolist`` so that
        multi-dimensional (and 0-d) values become plain nested Python
        lists/scalars rather than lists of numpy arrays.
    """
    quantity = dct
    value = quantity.value
    if isinstance(value, np.ndarray):
        value = value.tolist()
    return dict(__astropy_quantity__=True, value=value, unit=str(quantity.unit))
def decode_astropy_cosmology(dct):
    """Rebuild an astropy cosmology from its tagged-dictionary form.

    The class is looked up by ``dct["__name__"]`` in ``astropy.cosmology``.
    The ``__cosmology__`` and ``__name__`` markers are then removed from
    ``dct`` (in place), and the remaining entries are passed as keyword
    arguments. If astropy cannot be imported, the dictionary is returned
    unchanged.
    """
    try:
        from astropy import cosmology as cosmo

        cosmology_class = getattr(cosmo, dct["__name__"])
        for marker in ("__cosmology__", "__name__"):
            del dct[marker]
        return cosmology_class(**dct)
    except ImportError:
        logger.debug(
            "Cannot import astropy, cosmological priors may not be " "properly loaded."
        )
        return dct
def decode_astropy_quantity(dct):
    """Rebuild an astropy ``Quantity`` from its tagged-dictionary form.

    Returns ``None`` when the stored value is ``None``. Otherwise the
    ``__astropy_quantity__`` marker is removed from ``dct`` (in place), and
    the remaining entries are passed to ``units.Quantity``. If astropy
    cannot be imported, the dictionary is returned unchanged.
    """
    try:
        from astropy import units

        if dct["value"] is not None:
            dct.pop("__astropy_quantity__")
            return units.Quantity(**dct)
        return None
    except ImportError:
        logger.debug(
            "Cannot import astropy, cosmological priors may not be " "properly loaded."
        )
        return dct
def load_json(filename, gzip):
    """ Load a (optionally gzip-compressed) bilby JSON file

    Parameters
    ==========
    filename: str or path-like
        The file to read
    gzip: bool
        If True, read the file as gzip-compressed. Files with a ``.gz``
        extension are always treated as gzip-compressed.

    Returns
    =======
    dictionary: object
        The parsed JSON, with tagged objects converted by
        ``decode_bilby_json`` (used as the ``object_hook``).
    """
    use_gzip = gzip or os.path.splitext(filename)[1].lstrip(".") == "gz"
    if use_gzip:
        # Aliased so the module does not shadow the ``gzip`` flag argument.
        import gzip as gzip_module

        with gzip_module.open(filename, "rt", encoding="utf-8") as file:
            dictionary = json.load(file, object_hook=decode_bilby_json)
    else:
        # Read as UTF-8 explicitly (like the gzip branch) rather than relying
        # on the platform's locale encoding.
        with open(filename, "r", encoding="utf-8") as file:
            dictionary = json.load(file, object_hook=decode_bilby_json)
    return dictionary
def decode_bilby_json(dct):
    """ Convert a bilby-tagged JSON dictionary back into a Python object

    Designed to be used as the ``object_hook`` of ``json.load``/``json.loads``.
    Dictionaries carrying one of the marker keys checked below
    (``__prior_dict__``, ``__prior__``, ``__cosmology__``,
    ``__astropy_quantity__``, ``__array__``, ``__complex__``,
    ``__dataframe__``, ``__series__``, ``__function__``/``__class__``,
    ``__timedelta__``) are converted; any other dictionary is returned as-is.

    Parameters
    ==========
    dct: dict
        A dictionary produced by the JSON decoder

    Returns
    =======
    object
        The decoded object, or ``dct`` itself if it carries no known marker.
    """
    if dct.get("__prior_dict__", False):
        cls = getattr(import_module(dct["__module__"]), dct["__name__"])
        obj = cls._get_from_json_dict(dct)
        return obj
    if dct.get("__prior__", False):
        try:
            cls = getattr(import_module(dct["__module__"]), dct["__name__"])
        except (AttributeError, ImportError):
            # Either the class or its whole module is unavailable here; fall
            # back to a base Prior built only from the generic keyword args.
            logger.warning(
                "Unknown prior class for parameter {}, defaulting to base Prior object".format(
                    dct["kwargs"]["name"]
                )
            )
            from ..prior import Prior

            allowed_keys = {"name", "latex_label", "unit", "minimum", "maximum", "boundary"}
            dct["kwargs"] = {
                key: value for key, value in dct["kwargs"].items() if key in allowed_keys
            }
            cls = Prior
        obj = cls(**dct["kwargs"])
        return obj
    if dct.get("__cosmology__", False):
        return decode_astropy_cosmology(dct)
    if dct.get("__astropy_quantity__", False):
        return decode_astropy_quantity(dct)
    if dct.get("__array__", False):
        return np.asarray(dct["content"])
    if dct.get("__complex__", False):
        return complex(dct["real"], dct["imag"])
    if dct.get("__dataframe__", False):
        return pd.DataFrame(dct["content"])
    if dct.get("__series__", False):
        return pd.Series(dct["content"])
    if dct.get("__function__", False) or dct.get("__class__", False):
        default = ".".join([dct["__module__"], dct["__name__"]])
        try:
            module = import_module(dct["__module__"])
        except ImportError:
            # Same fallback as for a missing attribute: keep the dotted path.
            return default
        return getattr(module, dct["__name__"], default)
    if dct.get("__timedelta__", False):
        return timedelta(seconds=dct["__total_seconds__"])
    return dct
def recursively_decode_bilby_json(dct):
    """
    Recursively call `decode_bilby_json` on a dictionary and its contents

    Nested dictionaries, including dictionaries inside lists, are decoded
    before the dictionary that contains them. This is the same bottom-up
    order that ``json.loads(..., object_hook=decode_bilby_json)`` uses, so an
    encoded object whose contents are themselves encoded receives
    already-decoded values.

    Parameters
    ----------
    dct: dict
        The dictionary to decode. Nested dictionaries and lists are updated
        in place.

    Returns
    -------
    dct: object
        The original dictionary with all the elements decoded if possible, or
        the object `decode_bilby_json` converted it into. Values that are
        neither dictionaries nor lists are returned unchanged.
    """
    if isinstance(dct, dict):
        # Iterate over a snapshot of the keys; values are replaced in place.
        for key in list(dct):
            dct[key] = recursively_decode_bilby_json(dct[key])
        return decode_bilby_json(dct)
    if isinstance(dct, list):
        for index, value in enumerate(dct):
            dct[index] = recursively_decode_bilby_json(value)
    return dct
def decode_from_hdf5(item):
    """
    Decode an item from HDF5 format to python type.

    Conversions performed:

    - the ``"__none__"`` / ``b"__none__"`` placeholder becomes ``None``;
    - ``bytes``/``bytearray`` are decoded to ``str``;
    - non-empty arrays of byte strings (fixed-width ``S`` dtype, or arrays
      whose first element is ``bytes``) become (nested) lists, decoding each
      element with these same rules so ``b"__none__"`` entries become ``None``;
    - 0-d arrays are decoded as their contained scalar;
    - ``numpy.bool_`` becomes ``bool``.

    Anything else is returned unchanged.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    item: object
        Item to be decoded

    Returns
    -------
    output: object
        Converted input item
    """
    if isinstance(item, str) and item == "__none__":
        output = None
    elif isinstance(item, bytes) and item == b"__none__":
        output = None
    elif isinstance(item, (bytes, bytearray)):
        output = item.decode()
    elif isinstance(item, np.ndarray):
        if item.size == 0:
            output = item
        elif item.ndim == 0:
            # 0-d arrays can be neither iterated nor indexed with [0].
            output = decode_from_hdf5(item[()])
        elif item.dtype.kind == "S" or isinstance(item.flat[0], bytes):
            # Decode element-wise (recursing into rows of N-d arrays) so that
            # None placeholders written inside lists are restored.
            output = [decode_from_hdf5(element) for element in item]
        else:
            output = item
    elif isinstance(item, np.bool_):
        output = bool(item)
    else:
        output = item
    return output
def encode_for_hdf5(key, item):
    """
    Encode an item to a HDF5 saveable format.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    key: str
        Name the item will be stored under; only used in error messages.
    item: object
        Object to be encoded, specific options are provided for Bilby types

    Returns
    -------
    output: object
        Input item converted into HDF5 saveable format. Numpy scalars become
        Python scalars, unicode arrays become UTF-8 byte-string arrays, and
        ``None`` becomes the ``"__none__"`` placeholder. Tuples, DataFrames,
        functions and classes are converted to dictionaries, and plain
        dictionaries are copied.

    Raises
    ------
    ValueError
        If the item, or the contents of a list, has an unsupported type.
    """
    from ..prior.dict import PriorDict

    # Convert numpy scalars to builtin scalars. Checking the abstract numpy
    # scalar types means e.g. np.int32, np.float32 and np.bool_ are accepted,
    # not only np.int_, np.float64 and np.complex128.
    if isinstance(item, np.bool_):
        item = bool(item)
    elif isinstance(item, np.integer):
        item = int(item)
    elif isinstance(item, np.floating):
        item = float(item)
    elif isinstance(item, np.complexfloating):
        item = complex(item)
    if isinstance(item, np.ndarray) and item.dtype.kind == 'U':
        # Numpy's wide unicode strings are not supported by hdf5. Encode as
        # UTF-8 (as the list branch below does); dtype='S' only handles ASCII.
        logger.debug(f'converting dtype {item.dtype} for hdf5')
        item = np.char.encode(item, "utf-8")
    if isinstance(item, (np.ndarray, int, float, complex, str, bytes)):
        output = item
    elif item is None:
        output = "__none__"
    elif isinstance(item, list):
        if len(item) == 0:
            output = item
        else:
            item_array = np.array(item)
            if np.issubdtype(item_array.dtype, np.number) or item_array.dtype.kind == "b":
                # Homogeneous numeric or boolean lists are stored as arrays.
                output = item_array
            elif issubclass(item_array.dtype.type, str) or any(
                value is None for value in item
            ):
                # Store as UTF-8 byte strings, using b"__none__" for None.
                output = list()
                for value in item:
                    if isinstance(value, str):
                        output.append(value.encode("utf-8"))
                    elif isinstance(value, bytes):
                        output.append(value)
                    elif value is None:
                        output.append(b"__none__")
                    else:
                        output.append(str(value).encode("utf-8"))
            else:
                raise ValueError(f'Cannot save {key}: {type(item)} type')
    elif isinstance(item, PriorDict):
        output = json.dumps(item._get_json_dict())
    elif isinstance(item, pd.DataFrame):
        output = item.to_dict(orient="list")
    elif inspect.isfunction(item) or inspect.isclass(item):
        output = dict(
            __module__=item.__module__, __name__=item.__name__, __class__=True
        )
    elif isinstance(item, dict):
        output = item.copy()
    elif isinstance(item, tuple):
        output = {str(ii): elem for ii, elem in enumerate(item)}
    elif isinstance(item, datetime.timedelta):
        output = item.total_seconds()
    else:
        raise ValueError(f'Cannot save {key}: {type(item)} type')
    return output
def recursively_load_dict_contents_from_group(h5file, path):
    """
    Recursively load a HDF5 group into a dictionary

    Datasets are read in full (``node[()]``) and passed through
    ``decode_from_hdf5``. Sub-groups become nested dictionaries. Nodes that
    are neither datasets nor groups are skipped.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    h5file: h5py.File
        Open h5py file object
    path: str
        Path of the group within the HDF5 file (expected to end in "/")

    Returns
    -------
    output: dict
        The contents of the group, keyed by member name
    """
    import h5py

    contents = dict()
    for name, node in h5file[path].items():
        if isinstance(node, h5py.Group):
            child_path = path + name + "/"
            contents[name] = recursively_load_dict_contents_from_group(h5file, child_path)
        elif isinstance(node, h5py.Dataset):
            contents[name] = decode_from_hdf5(node[()])
    return contents
def recursively_save_dict_contents_to_group(h5file, path, dic):
    """
    Recursively save a dictionary to a HDF5 group

    Each value is first passed through ``encode_for_hdf5``. Encoded values
    that are dictionaries are written as sub-groups at ``path + key + "/"``.
    Everything else is assigned to ``h5file[path + key]``.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    h5file: h5py.File
        Open HDF5 file
    path: str
        Path inside the HDF5 file
    dic: dict
        The dictionary containing the data
    """
    for key, value in dic.items():
        encoded = encode_for_hdf5(key, value)
        target = path + key
        if not isinstance(encoded, dict):
            h5file[target] = encoded
            continue
        recursively_save_dict_contents_to_group(h5file, target + "/", encoded)
def safe_file_dump(data, filename, module):
    """ Safely dump data to a .pickle file

    The data are written to ``filename + ".temp"`` first. That file is then
    moved over ``filename`` only once the dump has succeeded, so a failed
    dump never truncates an existing ``filename``. The temporary file is
    removed if anything goes wrong.

    Parameters
    ==========
    data:
        data to dump
    filename: str or path-like
        The file to dump to
    module: pickle, dill, str
        The python module to use (anything providing ``dump(data, file)``).
        If a string, the module will be imported.
    """
    if isinstance(module, str):
        module = import_module(module)
    filename = os.fspath(filename)
    temp_filename = filename + ".temp"
    try:
        with open(temp_filename, "wb") as file:
            module.dump(data, file)
        # os.replace overwrites an existing destination on every platform and
        # is atomic here because source and destination share a directory.
        os.replace(temp_filename, filename)
    finally:
        # Only still present if the dump or the replace failed.
        if os.path.exists(temp_filename):
            os.remove(temp_filename)
def move_old_file(filename, overwrite=False):
    """ Move an existing file out of the way, or delete it

    If ``filename`` is an existing file, it is removed when ``overwrite`` is
    true. Otherwise it is renamed to ``filename + '.old'``. A debug message
    announcing the save to ``filename`` is logged in every case.

    Parameters
    ==========
    filename: str
        Name of the file to be moved
    overwrite: bool, optional
        Whether to remove the file (True) or rename it to
        filename + '.old' (False)
    """
    if os.path.isfile(filename):
        if not overwrite:
            logger.debug(
                "Renaming existing file {} to {}.old".format(filename, filename)
            )
            shutil.move(filename, filename + ".old")
        else:
            logger.debug("Removing existing file {}".format(filename))
            os.remove(filename)
    logger.debug("Saving result to {}".format(filename))
def safe_save_figure(fig, filename, **kwargs):
    """ Save a matplotlib figure, retrying without tex if the first attempt fails

    The output directory is created if needed. If ``fig.savefig`` raises a
    ``RuntimeError``, the global ``text.usetex`` rcParam is set to False and
    the save is attempted once more. That setting stays False afterwards.

    Parameters
    ==========
    fig: matplotlib.figure.Figure
        The figure to save
    filename: str
        The output file
    **kwargs:
        Passed through to ``fig.savefig``
    """
    output_directory = os.path.dirname(filename)
    check_directory_exists_and_if_not_mkdir(output_directory)
    import matplotlib

    try:
        fig.savefig(fname=filename, **kwargs)
    except RuntimeError:
        # NOTE(review): the failure is assumed to come from tex rendering; the
        # only fallback attempted is disabling usetex before retrying.
        logger.debug("Failed to save plot with tex labels turning off tex.")
        matplotlib.rcParams["text.usetex"] = False
        fig.savefig(fname=filename, **kwargs)