LALInference 4.1.9.1-00ddc7f
hdf5.py
# Copyright (C) 2016 Leo Singer, John Veitch
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
"""
Reading HDF5 posterior sample chain files.
"""

import numpy as np
import h5py
from astropy.table import Column, Table
from lalinference import LALInferenceHDF5PosteriorSamplesDatasetName \
    as POSTERIOR_SAMPLES
from lalinference import LALINFERENCE_PARAM_LINEAR as LINEAR
from lalinference import LALINFERENCE_PARAM_CIRCULAR as CIRCULAR
from lalinference import LALINFERENCE_PARAM_FIXED as FIXED
from lalinference import LALINFERENCE_PARAM_OUTPUT as OUTPUT

__all__ = ('read_samples', 'write_samples', 'extract_metadata')


def _identity(x):
    return x


_colname_map = (('rightascension', 'ra', _identity),
                ('declination', 'dec', _identity),
                ('logdistance', 'dist', np.exp),
                ('distance', 'dist', _identity),
                ('polarisation', 'psi', _identity),
                ('chirpmass', 'mc', _identity),
                ('a_spin1', 'a1', _identity),
                ('a_spin2', 'a2', _identity),
                ('tilt_spin1', 'tilt1', _identity),
                ('tilt_spin2', 'tilt2', _identity))


def _remap_colnames(table):
    for old_name, new_name, func in _colname_map:
        if old_name in table.colnames:
            table[new_name] = func(table.columns.pop(old_name))

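# Illustration only, not part of the original module: a minimal sketch of what
# _remap_colnames does to a table that uses LALInference column names. The
# helper name _demo_remap_colnames is hypothetical.
def _demo_remap_colnames():
    t = Table({'rightascension': np.array([0.1, 0.2]),
               'logdistance': np.log(np.array([100.0, 200.0]))})
    _remap_colnames(t)
    # 'rightascension' is renamed to 'ra'; 'logdistance' is renamed to 'dist'
    # and exponentiated back into a distance.
    assert sorted(t.colnames) == ['dist', 'ra']
    assert np.allclose(t['dist'], [100.0, 200.0])
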

def _find_table(group, tablename):
    """Recursively search an HDF5 group or file for a dataset by name.

    Parameters
    ----------
    group : `h5py.File` or `h5py.Group`
        The file or group to search
    tablename : str
        The name of the table to search for

    Returns
    -------
    table : `h5py.Dataset`
        The dataset whose name is `tablename`

    Raises
    ------
    KeyError
        If the table is not found or if multiple matching tables are found

    Check that we can find a table by name:
    >>> import os.path
    >>> from tempfile import TemporaryDirectory
    >>> table = Table(np.eye(3), names=['a', 'b', 'c'])
    >>> with TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'bar')
    <HDF5 dataset "bar": shape (3,), type "|V24">

    Check that an exception is raised if the table is not found:
    >>> with TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'plugh')
    Traceback (most recent call last):
        ...
    KeyError: 'Table not found: plugh'

    Check that an exception is raised if multiple tables are found:
    >>> with TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'bat')
    Traceback (most recent call last):
        ...
    KeyError: 'Multiple tables called bat exist: foo/bat, foo/xyzzy/bat'
    """
    results = {}

    def visitor(key, value):
        _, _, name = key.rpartition('/')
        if name == tablename:
            results[key] = value

    group.visititems(visitor)

    if len(results) == 0:
        raise KeyError('Table not found: {0}'.format(tablename))

    if len(results) > 1:
        raise KeyError('Multiple tables called {0} exist: {1}'.format(
            tablename, ', '.join(sorted(results.keys()))))

    table, = results.values()
    return table

def read_samples(filename, path=None, tablename=POSTERIOR_SAMPLES):
    """Read an HDF5 sample chain file.

    Parameters
    ----------
    filename : str
        The path of the HDF5 file on the filesystem.
    path : str, optional
        The path of the dataset within the HDF5 file.
    tablename : str, optional
        The name of the table to search for recursively within the HDF5 file.
        By default, search for 'posterior_samples'.

    Returns
    -------
    table : `astropy.table.Table`
        The sample chain as an Astropy table.

    Test reading a file written using the Python API:
    >>> import os.path
    >>> from tempfile import TemporaryDirectory
    >>> table = Table([
    ...     Column(np.ones(10), name='foo', meta={'vary': FIXED}),
    ...     Column(np.arange(10), name='bar', meta={'vary': LINEAR}),
    ...     Column(np.arange(10) * np.pi, name='bat', meta={'vary': CIRCULAR}),
    ...     Column(np.arange(10), name='baz', meta={'vary': OUTPUT})
    ... ])
    >>> with TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     write_samples(table, filename, path='foo/bar/posterior_samples')
    ...     len(read_samples(filename))
    10

    Test reading a file that was written using the LAL HDF5 C API:
    >>> table = read_samples('test.hdf5')
    >>> table.colnames
    ['uvw', 'opq', 'lmn', 'ijk', 'def', 'abc', 'rst', 'ghi']
    """
    with h5py.File(filename, 'r') as f:
        if path is not None:  # Look for a given path
            table = f[path]
        else:  # Look for a given table name
            table = _find_table(f, tablename)
        table = Table.read(table)

    # Restore vary types.
    for i, column in enumerate(table.columns.values()):
        column.meta['vary'] = table.meta['FIELD_{0}_VARY'.format(i)]

    # Restore fixed columns from table attributes.
    for key, value in table.meta.items():
        # Skip attributes from H5TB interface
        # (https://www.hdfgroup.org/HDF5/doc/HL/H5TB_Spec.html).
        if key in ('CLASS', 'VERSION', 'TITLE') or key.startswith('FIELD_'):
            continue
        if key in table.colnames:
            # Handled separately because rename_duplicate can trigger a bug
            # in astropy < 2.0.16.
            table.add_column(Column([value] * len(table), name=key,
                                    meta={'vary': FIXED}),
                             rename_duplicate=True)
        else:
            table.add_column(Column([value] * len(table), name=key,
                                    meta={'vary': FIXED}))

    # Delete remaining table attributes.
    table.meta.clear()

    # Normalize column names.
    _remap_colnames(table)

    # Done!
    return table

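# Illustration only, not part of the original module: a sketch of how a FIXED
# column round-trips through write_samples (defined below) and read_samples.
# write_samples stores fixed columns as HDF5 attributes; read_samples restores
# them as constant columns. The helper name and the HDF5 path are hypothetical.
def _demo_fixed_column_roundtrip(filename):
    t = Table([Column(np.ones(5), name='foo', meta={'vary': FIXED}),
               Column(np.arange(5.0), name='bar', meta={'vary': LINEAR})])
    write_samples(t, filename, path='lalinference/run0/posterior_samples')
    t2 = read_samples(filename)
    # 'foo' was written as an attribute rather than a dataset column, but it
    # comes back as a column of ones with the same length as the table.
    assert np.all(t2['foo'] == 1.0) and len(t2) == 5
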

def write_samples(table, filename, metadata=None, **kwargs):
    """Write an HDF5 sample chain file.

    Parameters
    ----------
    table : `astropy.table.Table`
        The sample chain as an Astropy table.
    filename : str
        The path of the HDF5 file on the filesystem.
    metadata : dict, optional
        Dictionary mapping internal HDF5 paths to dictionaries of attribute
        names and values to add to the output file.
    kwargs : dict, optional
        Any keyword arguments for `astropy.table.Table.write`.

    Check that we catch columns that are supposed to be FIXED but are not:
    >>> table = Table([
    ...     Column(np.arange(10), name='foo', meta={'vary': FIXED})
    ... ])
    >>> write_samples(table, 'bar.hdf5', path='bat/baz')  # doctest: +ELLIPSIS
    Traceback (most recent call last):
        ...
    AssertionError:
    Arrays are not equal
    Column foo is a fixed column, but its values are not identical
    ...

    And now try writing an arbitrary example to a temporary file.
    >>> import os.path
    >>> from tempfile import TemporaryDirectory
    >>> table = Table([
    ...     Column(np.ones(10), name='foo', meta={'vary': FIXED}),
    ...     Column(np.arange(10), name='bar', meta={'vary': LINEAR}),
    ...     Column(np.arange(10) * np.pi, name='bat', meta={'vary': CIRCULAR}),
    ...     Column(np.arange(10), name='baz', meta={'vary': OUTPUT})
    ... ])
    >>> with TemporaryDirectory() as dir:
    ...     write_samples(
    ...         table, os.path.join(dir, 'test.hdf5'), path='bat/baz')
    """
    # Copy the table so that we do not modify the original.
    table = table.copy()

    # Make sure that all columns have a 'vary' type.
    for column in table.columns.values():
        if 'vary' not in column.meta:
            if np.all(column[0] == column[1:]):
                column.meta['vary'] = FIXED
            else:
                column.meta['vary'] = OUTPUT
    # Reconstruct table attributes.
    for colname, column in tuple(table.columns.items()):
        if column.meta['vary'] == FIXED:
            np.testing.assert_array_equal(column[1:], column[0],
                                          'Column {0} is a fixed column, but '
                                          'its values are not identical'
                                          .format(column.name))
            table.meta[colname] = column[0]
            del table[colname]
    for i, column in enumerate(table.columns.values()):
        table.meta['FIELD_{0}_VARY'.format(i)] = column.meta['vary']
    table.write(filename, format='hdf5', **kwargs)
    if metadata:
        with h5py.File(filename, 'a') as hdf:
            for internal_path, attributes in metadata.items():
                for key, value in attributes.items():
                    try:
                        hdf[internal_path].attrs[key] = value
                    except KeyError:
                        raise KeyError(
                            'Unable to set metadata {0}[{1}] = {2}'.format(
                                internal_path, key, value))

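# Illustration only, not part of the original module: a sketch of the shape of
# the optional 'metadata' argument. Its keys are internal HDF5 paths that must
# already exist in the written file, and its values are dicts of attribute
# names and values. The path and attribute used here are hypothetical.
def _demo_write_samples_metadata(table, filename):
    write_samples(table, filename,
                  metadata={'lalinference/run0': {'number_live_points': 512}},
                  path='lalinference/run0/posterior_samples')
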

def update_metadata(metadata, level, attrs, strict_versions,
                    collision='raise'):
    """Update the sub-dictionary 'level' of 'metadata' with the values from
    'attrs', while enforcing that existing values are equal to those with
    which the dict is updated.
    """
    if level not in metadata:
        metadata[level] = {}
    for key in attrs:
        if key in metadata[level]:
            if collision == 'raise':
                if attrs[key] != metadata[level][key]:
                    if key == 'version' and not strict_versions:
                        continue
                    else:
                        raise ValueError(
                            'Metadata mismatch on level %r for key %r:\n'
                            '\t%r != %r'
                            % (level, key, attrs[key], metadata[level][key]))
            elif collision == 'append':
                if isinstance(metadata[level][key], list):
                    metadata[level][key].append(attrs[key])
                else:
                    metadata[level][key] = [metadata[level][key], attrs[key]]
            elif collision == 'ignore':
                pass
            else:
                raise ValueError('Invalid value for collision: %r' % collision)
        else:
            metadata[level][key] = attrs[key]
    return

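# Illustration only, not part of the original module: a sketch of the three
# collision modes. 'raise' rejects conflicting values (except mismatched
# 'version' attributes when strict_versions is False), 'append' accumulates
# them into a list, and 'ignore' keeps the first value. The levels and keys
# below are hypothetical.
def _demo_update_metadata():
    meta = {}
    update_metadata(meta, '/lalinference', {'version': 'a'},
                    strict_versions=False)
    # A different 'version' is tolerated because strict_versions is False;
    # the first value is kept.
    update_metadata(meta, '/lalinference', {'version': 'b'},
                    strict_versions=False)
    assert meta['/lalinference']['version'] == 'a'
    # With collision='append', conflicting values are collected in a list.
    update_metadata(meta, '/run0', {'seed': 1}, True)
    update_metadata(meta, '/run0', {'seed': 2}, True, collision='append')
    assert meta['/run0']['seed'] == [1, 2]
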

def extract_metadata(filename, metadata, log_noise_evidences=[],
                     log_max_likelihoods=[], nlive=[], dset_name=None,
                     nest=False, strict_versions=True):
    """
    Extract metadata from an HDF5 sample chain file.

    Parameters
    ----------
    filename : str
        The path of the HDF5 file on the filesystem.
    metadata : dict
        Dict into which to place metadata.
    log_noise_evidences : array, optional
        Array into which to place log noise evidences (if nest = True).
    log_max_likelihoods : array, optional
        Array into which to place log max likelihoods (if nest = True).
    nlive : array, optional
        Array into which to place numbers of live points (if nest = True).
    dset_name : str, optional
        Name of the dataset whose metadata is extracted; defaults to the
        posterior samples dataset name.
    nest : bool, optional, default False
        Whether to output quantities that only exist for nest runs.
    strict_versions : bool, optional, default True
        Whether to raise an error when 'version' attributes differ between
        metadata levels.

    Returns
    -------
    run_identifier : str
        The run identifier.
    """
    with h5py.File(filename, 'r') as hdf:
        # Walk down the groups until the actual data is reached, storing
        # metadata for each step.
        current_level = '/lalinference'
        group = hdf[current_level]
        update_metadata(metadata, current_level, group.attrs, strict_versions)

        if len(hdf[current_level].keys()) != 1:
            raise KeyError('Multiple run-identifiers found: %r'
                           % list(hdf[current_level].keys()))
        # We ensured above that there is only one identifier in the group.
        run_identifier = list(hdf[current_level].keys())[0]

        current_level = '/lalinference/' + run_identifier
        group = hdf[current_level]
        update_metadata(metadata, current_level, group.attrs, strict_versions,
                        collision='append')

        if nest:
            # Store the noise evidence and max likelihood separately for
            # later use.
            log_noise_evidences.append(group.attrs['log_noise_evidence'])
            log_max_likelihoods.append(group.attrs['log_max_likelihood'])
            nlive.append(group.attrs['number_live_points'])

        # Storing the metadata under the posterior samples dataset name
        # simplifies writing it into the output HDF file.
        if dset_name is None:
            dset_name = POSTERIOR_SAMPLES
        current_level = '/lalinference/' + run_identifier + '/' + dset_name
        current_level_posterior = ('/lalinference/' + run_identifier + '/' +
                                   POSTERIOR_SAMPLES)
        group = hdf[current_level]
        update_metadata(metadata, current_level_posterior, group.attrs,
                        strict_versions, collision='ignore')

    return run_identifier
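

# Illustration only, not part of the original module: a sketch of gathering
# metadata and nested-sampling quantities from several files, assuming each
# file follows the usual '/lalinference/<run_identifier>' layout and, for
# nest=True, carries 'log_noise_evidence', 'log_max_likelihood' and
# 'number_live_points' attributes. The helper name is hypothetical.
def _demo_extract_metadata(filenames):
    metadata = {}
    log_noise_evidences, log_max_likelihoods, nlive = [], [], []
    for fname in filenames:
        # The run identifier of the last file is returned for reference.
        run_identifier = extract_metadata(
            fname, metadata, log_noise_evidences, log_max_likelihoods,
            nlive, nest=True)
    return run_identifier, metadata, log_noise_evidences, \
        log_max_likelihoods, nlive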