# Source code for ligo.skymap.io.hdf5

# Copyright (C) 2016-2022  Leo Singer, John Veitch
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
"""Read and write HDF5 posterior sample chain files."""

import numpy as np
import h5py
from astropy.table import Column, Table

# Constants from lalinference module

# Default dataset name that `read_samples` searches for when no explicit
# path is given.
POSTERIOR_SAMPLES = 'posterior_samples'

# Values stored in each column's ``meta['vary']`` entry and persisted as the
# ``FIELD_<n>_VARY`` table attributes by `write_samples`.
# LINEAR and CIRCULAR are only stored and restored by this module, never
# interpreted here (presumably they describe how a parameter varies in
# lalinference -- TODO confirm against the lalinference definitions).
LINEAR = 0
CIRCULAR = 1
# FIXED columns hold one repeated value; `write_samples` stores that value
# as a table attribute instead of a column and `read_samples` expands it
# back into a column.
FIXED = 2
# Fallback type: used by `read_samples` when a column has no
# ``FIELD_<n>_VARY`` attribute, and by `write_samples` for columns without
# a 'vary' entry whose values are not all identical.
OUTPUT = 3

__all__ = ('read_samples', 'write_samples')


def _identity(x):
    return x


# Column name normalization table. Each entry is
# ``(name as found in the file, normalized name, transform)``, where the
# transform is applied to the column values when the column is renamed
# (e.g. ``logdistance`` is exponentiated to become ``dist``). Entries are
# processed in order by `_remap_colnames`.
_colname_map = (('rightascension', 'ra', _identity),
                ('right_ascension', 'ra', _identity),
                ('declination', 'dec', _identity),
                ('logdistance', 'dist', np.exp),
                ('distance', 'dist', _identity),
                ('luminosity_distance', 'dist', _identity),
                ('polarisation', 'psi', _identity),
                ('chirpmass', 'mc', _identity),
                ('chirp_mass', 'mc', _identity),
                ('a_spin1', 'a1', _identity),
                ('a_1', 'a1', _identity),
                ('a_spin2', 'a2', _identity),
                ('a_2', 'a2', _identity),
                ('tilt_spin1', 'tilt1', _identity),
                ('tilt_1', 'tilt1', _identity),
                ('tilt_spin2', 'tilt2', _identity),
                ('tilt_2', 'tilt2', _identity),
                ('geocent_time', 'time', _identity))


def _remap_colnames(table):
    """Rename known columns of `table` in place to their normalized names.

    For every entry of `_colname_map` whose source name is present, the
    column is removed from the table, passed through the entry's transform,
    and stored back under the normalized name.

    Parameters
    ----------
    table : `astropy.table.Table`
        The table to modify in place.

    """
    for source, target, transform in _colname_map:
        if source not in table.colnames:
            continue
        column = table.columns.pop(source)
        table[target] = transform(column)


def _find_table(group, tablename):
    """Recursively search an HDF5 group or file for a dataset by name.

    Parameters
    ----------
    group : `h5py.File` or `h5py.Group`
        The file or group to search
    tablename : str
        The name of the table to search for

    Returns
    -------
    dataset : `h5py.Dataset`
        The dataset whose name is `tablename`

    Raises
    ------
    KeyError
        If the table is not found or if multiple matching tables are found

    Examples
    --------
    Check that we can find a file by name:

    >>> import os.path
    >>> import tempfile
    >>> table = Table(np.eye(3), names=['a', 'b', 'c'])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'bar')
    <HDF5 dataset "bar": shape (3,), type "|V24">

    Check that an exception is raised if the table is not found:

    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'plugh')
    Traceback (most recent call last):
        ...
    KeyError: 'Table not found: plugh'

    Check that an exception is raised if multiple tables are found:

    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     table.write(filename, path='foo/bar', append=True)
    ...     table.write(filename, path='foo/bat', append=True)
    ...     table.write(filename, path='foo/xyzzy/bat', append=True)
    ...     with h5py.File(filename, 'r') as f:
    ...         _find_table(f, 'bat')
    Traceback (most recent call last):
        ...
    KeyError: 'Multiple tables called bat exist: foo/bat, foo/xyzzy/bat'

    """
    results = {}

    def visitor(key, value):
        _, _, name = key.rpartition('/')
        if name == tablename:
            results[key] = value

    group.visititems(visitor)

    if len(results) == 0:
        raise KeyError('Table not found: {0}'.format(tablename))

    if len(results) > 1:
        raise KeyError('Multiple tables called {0} exist: {1}'.format(
            tablename, ', '.join(sorted(results.keys()))))

    table, = results.values()
    return table


def read_samples(filename, path=None, tablename=POSTERIOR_SAMPLES):
    """Load a sample chain from an HDF5 file into an Astropy table.

    The per-column 'vary' types and the fixed-value columns, which are
    stored as table attributes in the file, are restored, and well-known
    column names are normalized.

    Parameters
    ----------
    filename : str
        The path of the HDF5 file on the filesystem. (The second example
        below passes an open binary file object instead.)
    path : str, optional
        The path of the dataset within the HDF5 file.
    tablename : str, optional
        The name of table to search for recursively within the HDF5 file.
        By default, search for 'posterior_samples'. Only used when `path`
        is not given.

    Returns
    -------
    chain : `astropy.table.Table`
        The sample chain as an Astropy table.

    Examples
    --------
    Test reading a file written using the Python API:

    >>> import os.path
    >>> import tempfile
    >>> table = Table([
    ...     Column(np.ones(10), name='foo', meta={'vary': FIXED}),
    ...     Column(np.arange(10), name='bar', meta={'vary': LINEAR}),
    ...     Column(np.arange(10) * np.pi, name='bat', meta={'vary': CIRCULAR}),
    ...     Column(np.arange(10), name='baz', meta={'vary': OUTPUT})
    ... ])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     filename = os.path.join(dir, 'test.hdf5')
    ...     write_samples(table, filename, path='foo/bar/posterior_samples')
    ...     len(read_samples(filename))
    10

    Test reading a file that was written using the LAL HDF5 C API:

    >>> from importlib.resources import files
    >>> with files('ligo.skymap.io.tests.data').joinpath(
    ...         'test.hdf5').open('rb') as f:
    ...     table = read_samples(f)
    >>> table.colnames
    ['uvw', 'opq', 'lmn', 'ijk', 'def', 'abc', 'ghi', 'rst']

    """
    with h5py.File(filename, 'r') as f:
        # An explicit path takes precedence; otherwise search the whole file
        # for a uniquely named table.
        dataset = f[path] if path is not None else _find_table(f, tablename)
        table = Table.read(dataset)

    # Restore vary types; columns without a stored type default to OUTPUT.
    for index, column in enumerate(table.columns.values()):
        vary_key = 'FIELD_{0}_VARY'.format(index)
        column.meta['vary'] = table.meta.get(vary_key, OUTPUT)

    # Restore fixed columns from table attributes, skipping the attributes
    # that belong to the H5TB interface
    # (https://www.hdfgroup.org/HDF5/doc/HL/H5TB_Spec.html).
    h5tb_names = ('CLASS', 'VERSION', 'TITLE')
    for name, value in table.meta.items():
        if name in h5tb_names or name.startswith('FIELD_'):
            continue
        fixed = Column([value] * len(table), name=name, meta={'vary': FIXED})
        table.add_column(fixed)

    # Delete remaining table attributes.
    table.meta.clear()

    # Normalize column names.
    _remap_colnames(table)

    return table


def write_samples(table, filename, metadata=None, **kwargs):
    """Write an HDF5 sample chain file.

    The table is copied before it is processed, so the caller's table is
    not modified. Columns whose ``meta['vary']`` is FIXED are removed from
    the copy and stored as table attributes; every remaining column's
    'vary' type is stored in a ``FIELD_<n>_VARY`` table attribute. Columns
    without a 'vary' entry are classified automatically: FIXED if all of
    their values are identical, otherwise OUTPUT.

    Parameters
    ----------
    table : `astropy.table.Table`
        The sample chain as an Astropy table.
    filename : str
        The path of the HDF5 file on the filesystem.
    metadata : dict, optional
        Dictionary of (path, value) pairs of metadata attributes
        to add to the output file. Each value is itself a dictionary
        mapping attribute names to attribute values, which are set on the
        object found at that path inside the written file.
    kwargs : dict
        Any keyword arguments for `astropy.table.Table.write`.

    Raises
    ------
    AssertionError
        If a column is marked FIXED but its values are not all identical.
    KeyError
        If setting a metadata attribute raises `KeyError` (for example
        because the given path does not exist in the written file).

    Examples
    --------
    Check that we catch columns that are supposed to be FIXED but are not:

    >>> table = Table([
    ...     Column(np.arange(10), name='foo', meta={'vary': FIXED})
    ... ])
    >>> write_samples(table, 'bar.hdf5', path='bat/baz')
    Traceback (most recent call last):
        ...
    AssertionError: 
    Arrays are not equal
    Column foo is a fixed column, but its values are not identical
    ...

    And now try writing an arbitrary example to a temporary file.

    >>> import os.path
    >>> import tempfile
    >>> table = Table([
    ...     Column(np.ones(10), name='foo', meta={'vary': FIXED}),
    ...     Column(np.arange(10), name='bar', meta={'vary': LINEAR}),
    ...     Column(np.arange(10) * np.pi, name='bat', meta={'vary': CIRCULAR}),
    ...     Column(np.arange(10), name='baz', meta={'vary': OUTPUT}),
    ...     Column(np.ones(10), name='plugh'),
    ...     Column(np.arange(10), name='xyzzy')
    ... ])
    >>> with tempfile.TemporaryDirectory() as dir:
    ...     write_samples(
    ...         table, os.path.join(dir, 'test.hdf5'), path='bat/baz',
    ...         metadata={'bat/baz': {'widget': 'shoephone'}})

    """  # noqa: W291
    # Copy the table so that we do not modify the original.
    table = table.copy()

    # Make sure that all columns have a 'vary' type.
    for column in table.columns.values():
        if 'vary' not in column.meta:
            # A column without rows has no value that could be stored as a
            # table attribute, so it is never auto-classified as FIXED
            # (indexing column[0] would raise IndexError).
            if len(column) > 0 and np.all(column[0] == column[1:]):
                column.meta['vary'] = FIXED
            else:
                column.meta['vary'] = OUTPUT

    # Reconstruct table attributes. Iterate over a snapshot of the columns
    # because FIXED columns are deleted from the table inside the loop.
    for colname, column in tuple(table.columns.items()):
        if column.meta['vary'] == FIXED:
            np.testing.assert_array_equal(column[1:], column[0],
                                          'Column {0} is a fixed column, but '
                                          'its values are not identical'
                                          .format(column.name))
            table.meta[colname] = column[0]
            del table[colname]

    # Field indices refer to the columns that remain after the FIXED
    # columns have been removed.
    for i, column in enumerate(table.columns.values()):
        table.meta['FIELD_{0}_VARY'.format(i)] = column.meta.pop('vary')

    table.write(filename, format='hdf5', **kwargs)

    if metadata:
        # Reopen the file that was just written to attach extra attributes.
        with h5py.File(filename, 'r+') as hdf:
            for internal_path, attributes in metadata.items():
                for key, value in attributes.items():
                    try:
                        hdf[internal_path].attrs[key] = value
                    except KeyError as err:
                        # Keep the original lookup failure as the explicit
                        # cause of the more descriptive error.
                        raise KeyError(
                            'Unable to set metadata {0}[{1}] = {2}'.format(
                                internal_path, key, value)) from err