Coverage for bilby/core/utils/io.py: 70%

238 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import datetime 

2import inspect 

3import json 

4import os 

5import shutil 

6from importlib import import_module 

7from pathlib import Path 

8from datetime import timedelta 

9 

10import numpy as np 

11import pandas as pd 

12 

13from .log import logger 

14from .introspection import infer_args_from_method 

15 

16 

def check_directory_exists_and_if_not_mkdir(directory):
    """ Make sure a directory is present on disk, creating it when missing

    Missing parent directories are created too, and a directory that is
    already there is left untouched.

    Parameters
    ==========
    directory: str
        Name of the directory

    """
    target = Path(directory)
    target.mkdir(parents=True, exist_ok=True)

27 

28 

class BilbyJsonEncoder(json.JSONEncoder):
    """JSON encoder for objects the standard encoder cannot serialise.

    Unsupported objects are turned either into plain builtin values or into
    dictionaries carrying a marker key (``__prior__``, ``__array__``, ...)
    which `decode_bilby_json` looks for when the file is read back.
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Parameters
        ----------
        obj: object
            The object the standard encoder could not handle

        Returns
        -------
        object
            A builtin value or a marker dictionary describing ``obj``

        Raises
        ------
        TypeError
            Raised by ``json.JSONEncoder.default`` when no rule matches
        """
        # Imported here rather than at module level (kept as in the original).
        from ..prior import MultivariateGaussianDist, Prior, PriorDict
        from ...gw.prior import HealPixMapPriorDist
        from ...bilby_mcmc.proposals import ProposalCycle

        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, PriorDict):
            return {"__prior_dict__": True, "content": obj._get_json_dict()}
        if isinstance(obj, (MultivariateGaussianDist, HealPixMapPriorDist, Prior)):
            return {
                "__prior__": True,
                "__module__": obj.__module__,
                "__name__": obj.__class__.__name__,
                "kwargs": dict(obj.get_instantiation_dict()),
            }
        if isinstance(obj, ProposalCycle):
            return str(obj)
        try:
            from astropy import cosmology as cosmo, units

            if isinstance(obj, cosmo.FLRW):
                return encode_astropy_cosmology(obj)
            if isinstance(obj, units.Quantity):
                return encode_astropy_quantity(obj)
            if isinstance(obj, units.PrefixUnit):
                return str(obj)
        except ImportError:
            # astropy is optional: without it the remaining rules still apply.
            logger.debug("Cannot import astropy, cannot write cosmological priors")
        if isinstance(obj, np.ndarray):
            return {"__array__": True, "content": obj.tolist()}
        if isinstance(obj, complex):
            return {"__complex__": True, "real": obj.real, "imag": obj.imag}
        if isinstance(obj, pd.DataFrame):
            return {"__dataframe__": True, "content": obj.to_dict(orient="list")}
        if isinstance(obj, pd.Series):
            return {"__series__": True, "content": obj.to_dict()}
        if inspect.isfunction(obj):
            return {
                "__function__": True,
                "__module__": obj.__module__,
                "__name__": obj.__name__,
            }
        if inspect.isclass(obj):
            return {
                "__class__": True,
                "__module__": obj.__module__,
                "__name__": obj.__name__,
            }
        if isinstance(obj, timedelta):
            # The unreachable ``return obj.isoformat()`` that used to follow
            # this return has been removed.
            return {
                "__timedelta__": True,
                "__total_seconds__": obj.total_seconds()
            }
        return json.JSONEncoder.default(self, obj)

88 

89 

def encode_astropy_cosmology(obj):
    """Describe a cosmology object as a dictionary of its constructor arguments.

    Each argument name reported by ``infer_args_from_method`` for
    ``obj.__init__`` is read back from ``obj`` as an attribute. The markers
    ``__cosmology__`` and ``__name__`` (the class name) are added afterwards.

    Parameters
    ----------
    obj: object
        The cosmology instance to encode

    Returns
    -------
    dict
        Constructor arguments plus the two marker entries
    """
    encoded = {}
    for arg_name in infer_args_from_method(obj.__init__):
        encoded[arg_name] = getattr(obj, arg_name)
    encoded["__cosmology__"] = True
    encoded["__name__"] = obj.__class__.__name__
    return encoded

96 

97 

def encode_astropy_quantity(dct):
    """Describe a quantity as a dictionary holding its value and unit string.

    Parameters
    ----------
    dct: object
        Despite the name, this is the quantity itself; its ``value`` and
        ``unit`` attributes are read

    Returns
    -------
    dict
        ``__astropy_quantity__`` marker, ``value`` and ``unit`` (as a string);
        an array value is converted with ``list``
    """
    value = dct.value
    unit_name = str(dct.unit)
    if isinstance(value, np.ndarray):
        value = list(value)
    return {"__astropy_quantity__": True, "value": value, "unit": unit_name}

103 

104 

def decode_astropy_cosmology(dct):
    """Rebuild a cosmology object from its encoded dictionary.

    The class named by ``dct["__name__"]`` is looked up in
    ``astropy.cosmology`` and called with the remaining entries as keyword
    arguments. The two marker entries are removed from ``dct`` in place.

    Parameters
    ----------
    dct: dict
        Encoded cosmology containing ``__cosmology__`` and ``__name__``

    Returns
    -------
    object
        The cosmology instance, or ``dct`` itself when an ImportError occurs
    """
    try:
        from astropy import cosmology as cosmo

        cosmology_class = getattr(cosmo, dct["__name__"])
        dct.pop("__cosmology__")
        dct.pop("__name__")
        return cosmology_class(**dct)
    except ImportError:
        logger.debug(
            "Cannot import astropy, cosmological priors may not be properly loaded."
        )
        return dct

117 

118 

def decode_astropy_quantity(dct):
    """Rebuild an astropy quantity from its encoded dictionary.

    Parameters
    ----------
    dct: dict
        Encoded quantity; the ``__astropy_quantity__`` marker is removed in
        place and the remaining entries are passed to ``units.Quantity``

    Returns
    -------
    object
        ``None`` when the stored value is ``None``, the quantity otherwise,
        or ``dct`` itself when an ImportError occurs
    """
    try:
        from astropy import units

        if dct["value"] is None:
            return None
        dct.pop("__astropy_quantity__")
        return units.Quantity(**dct)
    except ImportError:
        logger.debug(
            "Cannot import astropy, cosmological priors may not be properly loaded."
        )
        return dct

133 

134 

def load_json(filename, gzip):
    """Load a JSON file, decoding bilby marker dictionaries on the way.

    Parameters
    ----------
    filename: str
        Path of the file to read
    gzip: bool
        If true the file is read as gzip regardless of its name; a file
        whose extension is ``gz`` is always read as gzip

    Returns
    -------
    dictionary: object
        The decoded JSON content
    """
    is_gzipped = gzip or os.path.splitext(filename)[1].lstrip(".") == "gz"
    if is_gzipped:
        # Aliased so that the module does not rebind the ``gzip`` argument.
        import gzip as gzip_module

        with gzip_module.open(filename, "rt", encoding="utf-8") as file:
            dictionary = json.load(file, object_hook=decode_bilby_json)
    else:
        # Explicit UTF-8 so both branches decode the same way on every platform.
        with open(filename, "r", encoding="utf-8") as file:
            dictionary = json.load(file, object_hook=decode_bilby_json)
    return dictionary

146 

147 

def decode_bilby_json(dct):
    """Turn a marker dictionary back into the object it describes.

    The dictionary is checked for each supported marker key in turn; the
    first truthy marker decides how it is decoded. A dictionary without any
    marker is returned unchanged.

    Parameters
    ----------
    dct: dict
        A dictionary produced while parsing JSON

    Returns
    -------
    object
        The decoded object, or ``dct`` itself when no marker is present
    """
    if dct.get("__prior_dict__", False):
        prior_dict_cls = getattr(import_module(dct["__module__"]), dct["__name__"])
        return prior_dict_cls._get_from_json_dict(dct)

    if dct.get("__prior__", False):
        try:
            prior_cls = getattr(import_module(dct["__module__"]), dct["__name__"])
        except AttributeError:
            logger.warning(
                "Unknown prior class for parameter {}, defaulting to base Prior object".format(
                    dct["kwargs"]["name"]
                )
            )
            from ..prior import Prior

            # Only these keyword arguments are kept for the fallback class;
            # everything else is dropped from the stored kwargs in place.
            kept = ["name", "latex_label", "unit", "minimum", "maximum", "boundary"]
            dropped = [name for name in dct["kwargs"] if name not in kept]
            for name in dropped:
                dct["kwargs"].pop(name)
            prior_cls = Prior
        return prior_cls(**dct["kwargs"])

    if dct.get("__cosmology__", False):
        return decode_astropy_cosmology(dct)

    if dct.get("__astropy_quantity__", False):
        return decode_astropy_quantity(dct)

    if dct.get("__array__", False):
        return np.asarray(dct["content"])

    if dct.get("__complex__", False):
        return complex(dct["real"], dct["imag"])

    if dct.get("__dataframe__", False):
        return pd.DataFrame(dct["content"])

    if dct.get("__series__", False):
        return pd.Series(dct["content"])

    if dct.get("__function__", False) or dct.get("__class__", False):
        module_name = dct["__module__"]
        attr_name = dct["__name__"]
        # When the attribute is missing, the dotted name is returned instead.
        dotted_name = ".".join([module_name, attr_name])
        return getattr(import_module(module_name), attr_name, dotted_name)

    if dct.get("__timedelta__", False):
        return timedelta(seconds=dct["__total_seconds__"])

    return dct

188 

189 

def recursively_decode_bilby_json(dct):
    """
    Recursively call `decode_bilby_json`

    The dictionary itself is decoded first; if the result is still a
    dictionary, each of its values that is a dictionary is decoded in turn.

    Parameters
    ----------
    dct: dict
        The dictionary to decode

    Returns
    -------
    dct: dict
        The original dictionary with all the elements decoded if possible
    """
    decoded = decode_bilby_json(dct)
    if not isinstance(decoded, dict):
        return decoded
    for name in decoded:
        if isinstance(decoded[name], dict):
            decoded[name] = recursively_decode_bilby_json(decoded[name])
    return decoded

210 

211 

def decode_from_hdf5(item):
    """
    Convert a value read from HDF5 into its python counterpart.

    The ``__none__`` placeholder (as text or bytes) becomes None, byte
    strings are decoded to text, non-empty arrays of byte strings become
    lists of text and numpy booleans become builtin booleans. Anything else
    is handed back as it came in.

    .. versionadded:: 1.0.0

    Parameters
    ----------
    item: object
        Item to be decoded

    Returns
    -------
    output: object
        Converted input item
    """
    if isinstance(item, str):
        return None if item == "__none__" else item
    if isinstance(item, bytes):
        return None if item == b"__none__" else item.decode()
    if isinstance(item, bytearray):
        return item.decode()
    if isinstance(item, np.ndarray):
        # The size check comes first so that item[0] is never read on an empty array.
        holds_bytes = item.size != 0 and (
            "|S" in str(item.dtype) or isinstance(item[0], bytes)
        )
        if holds_bytes:
            return [entry.decode() for entry in item]
        return item
    if isinstance(item, np.bool_):
        return bool(item)
    return item

248 

249 

def _encode_list_for_hdf5(key, item):
    """Encode a list: numeric lists become arrays, text/None lists become bytes."""
    if len(item) == 0:
        return item
    item_array = np.array(item)
    if np.issubdtype(item_array.dtype, np.number):
        return item_array
    if issubclass(item_array.dtype.type, str) or None in item:
        encoded = list()
        for value in item:
            if isinstance(value, str):
                encoded.append(value.encode("utf-8"))
            elif isinstance(value, bytes):
                encoded.append(value)
            elif value is None:
                encoded.append(b"__none__")
            else:
                encoded.append(str(value).encode("utf-8"))
        return encoded
    raise ValueError(f'Cannot save {key}: {type(item)} type')


def encode_for_hdf5(key, item):
    """
    Encode an item to a HDF5 saveable format.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    key: str
        Name of the item, only used in the error message
    item: object
        Object to be encoded, specific options are provided for Bilby types

    Returns
    -------
    output: object
        Input item converted into HDF5 saveable format

    Raises
    ------
    ValueError
        If the item, or the content of a list, has no supported encoding
    """
    # Unwrap numpy scalars of every width, not only the platform defaults.
    # np.timedelta64 subclasses np.integer but is left alone so that it is
    # still rejected below instead of silently losing its unit.
    if isinstance(item, np.bool_):
        item = bool(item)
    elif isinstance(item, np.integer) and not isinstance(item, np.timedelta64):
        item = int(item)
    elif isinstance(item, np.floating):
        item = float(item)
    elif isinstance(item, np.complexfloating):
        item = complex(item)
    elif isinstance(item, np.ndarray):
        # Numpy's wide unicode strings are not supported by hdf5
        if item.dtype.kind == 'U':
            logger.debug(f'converting dtype {item.dtype} for hdf5')
            item = np.array(item, dtype='S')

    if isinstance(item, (np.ndarray, int, float, complex, str, bytes)):
        return item
    if item is None:
        return "__none__"
    if isinstance(item, list):
        return _encode_list_for_hdf5(key, item)

    # Imported only once the simple cases above have been ruled out.
    from ..prior.dict import PriorDict

    if isinstance(item, PriorDict):
        return json.dumps(item._get_json_dict())
    if isinstance(item, pd.DataFrame):
        return item.to_dict(orient="list")
    if inspect.isfunction(item) or inspect.isclass(item):
        return dict(
            __module__=item.__module__, __name__=item.__name__, __class__=True
        )
    if isinstance(item, dict):
        return item.copy()
    if isinstance(item, tuple):
        # Tuples are stored as a dictionary keyed by the element position.
        return {str(ii): elem for ii, elem in enumerate(item)}
    if isinstance(item, datetime.timedelta):
        return item.total_seconds()
    raise ValueError(f'Cannot save {key}: {type(item)} type')

319 

320 

def recursively_load_dict_contents_from_group(h5file, path):
    """
    Read a HDF5 group, and every group nested inside it, into a dictionary

    Datasets are read in full and passed through `decode_from_hdf5`; groups
    are loaded by recursing with the group name appended to ``path``.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    h5file: h5py.File
        Open h5py file object
    path: str
        Path within the HDF5 file

    Returns
    -------
    output: dict
        The contents of the HDF5 file unpacked into the dictionary.
    """
    import h5py

    contents = {}
    for name, node in h5file[path].items():
        if isinstance(node, h5py.Dataset):
            contents[name] = decode_from_hdf5(node[()])
            continue
        if isinstance(node, h5py.Group):
            child_path = path + name + "/"
            contents[name] = recursively_load_dict_contents_from_group(
                h5file, child_path
            )
    return contents

350 

351 

def recursively_save_dict_contents_to_group(h5file, path, dic):
    """
    Write a dictionary, including nested dictionaries, into a HDF5 file

    Every value is first passed through `encode_for_hdf5`. A value that is
    a dictionary after encoding is written recursively under
    ``path + key + "/"``; any other value is assigned to ``path + key``.

    .. versionadded:: 1.1.0

    Parameters
    ----------
    h5file: h5py.File
        Open HDF5 file
    path: str
        Path inside the HDF5 file
    dic: dict
        The dictionary containing the data
    """
    for name, raw_value in dic.items():
        encoded = encode_for_hdf5(name, raw_value)
        target = path + name
        if not isinstance(encoded, dict):
            h5file[target] = encoded
            continue
        recursively_save_dict_contents_to_group(h5file, target + "/", encoded)

373 

374 

def safe_file_dump(data, filename, module):
    """ Safely dump data to a .pickle file

    The data is first written to ``filename + ".temp"`` and that file is then
    moved onto ``filename``, so the target is only replaced after the dump
    has finished.

    Parameters
    ==========
    data:
        data to dump
    filename: str, os.PathLike
        The file to dump to
    module: pickle, dill, str
        The python module to use. If a string, the module will be imported
    """
    if isinstance(module, str):
        module = import_module(module)
    # Normalise path-like objects so that the suffix can be appended.
    filename = os.fspath(filename)
    temp_filename = filename + ".temp"
    with open(temp_filename, "wb") as file:
        module.dump(data, file)
    shutil.move(temp_filename, filename)

393 

394 

def move_old_file(filename, overwrite=False):
    """ Clears the way for a new file by deleting or renaming an existing one.

    Nothing is touched when no file of that name exists. The final
    "Saving result to" debug message is logged in every case.

    Parameters
    ==========
    filename: str
        Name of the file to be move
    overwrite: bool, optional
        Whether or not to remove the file or to change the name
        to filename + '.old'
    """
    already_present = os.path.isfile(filename)
    if already_present and overwrite:
        logger.debug("Removing existing file {}".format(filename))
        os.remove(filename)
    elif already_present:
        logger.debug(
            "Renaming existing file {} to {}.old".format(filename, filename)
        )
        shutil.move(filename, filename + ".old")
    logger.debug("Saving result to {}".format(filename))

416 

417 

def safe_save_figure(fig, filename, **kwargs):
    """ Save a figure, retrying without tex rendering on a RuntimeError

    The directory part of ``filename`` is created first if needed. When the
    first ``savefig`` call raises a RuntimeError, matplotlib's
    ``text.usetex`` setting is switched off and the save is attempted once
    more; the setting is not switched back afterwards.

    Parameters
    ==========
    fig:
        Object providing a ``savefig`` method
    filename: str
        The file to save the figure to
    **kwargs:
        Passed on to ``fig.savefig``
    """
    output_directory = os.path.dirname(filename)
    check_directory_exists_and_if_not_mkdir(output_directory)
    from matplotlib import rcParams

    def write_figure():
        fig.savefig(fname=filename, **kwargs)

    try:
        write_figure()
    except RuntimeError:
        logger.debug("Failed to save plot with tex labels turning off tex.")
        rcParams["text.usetex"] = False
        write_figure()