
# Copyright (C) 2018-2020  Leo Singer
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
import numpy as np
from tqdm import tqdm

from .ptemcee import Sampler

__all__ = ('ez_emcee',)


def logp(x, lo, hi):
    """Uniform (flat) log prior: 0 inside the box [lo, hi], -inf outside."""
    return np.where(((x >= lo) & (x <= hi)).all(-1), 0.0, -np.inf)

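# Illustration (not part of the original module; a minimal sketch): ``logp``
# broadcasts over leading axes, so the (ntemps, nwalkers, ndim) position
# array used below yields an (ntemps, nwalkers) array of log priors, e.g.:
#
#     pos = np.random.uniform(0.0, 2.0, size=(3, 8, 2))
#     logp(pos, np.zeros(2), np.ones(2))  # 0.0 inside the unit box, -inf outside
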

def ez_emcee(log_prob_fn, lo, hi, nindep=200,
             ntemps=10, nwalkers=None, nburnin=500,
             args=(), kwargs={}, **options):
    r'''Fire-and-forget MCMC sampling using `ptemcee.Sampler`, featuring
    automated convergence monitoring, progress tracking, and thinning.

    The parameters are bounded in the finite interval described by ``lo`` and
    ``hi`` (including ``-np.inf`` and ``np.inf`` for half-infinite or infinite
    domains).

    If run in an interactive terminal, live progress is shown including the
    current sample number, the total required number of samples, time elapsed
    and estimated time remaining, acceptance fraction, and autocorrelation
    length.

    Sampling terminates when all chains have accumulated the requested number
    of independent samples.

    Parameters
    ----------
    log_prob_fn : callable
        The log probability function. It should take as its argument the
        parameter vector as an array of length ``ndim``, or if it is
        vectorized, a 2D array with ``ndim`` columns.
    lo : list, `numpy.ndarray`
        List of lower limits of parameters, of length ``ndim``.
    hi : list, `numpy.ndarray`
        List of upper limits of parameters, of length ``ndim``.
    nindep : int, optional
        Minimum number of independent samples.
    ntemps : int, optional
        Number of temperatures.
    nwalkers : int, optional
        Number of walkers. The default is 4 times the number of dimensions.
    nburnin : int, optional
        Number of samples to discard during the burn-in phase.

    Returns
    -------
    chain : `numpy.ndarray`
        The thinned and flattened posterior sample chain, with at least
        ``nindep`` * ``nwalkers`` rows and exactly ``ndim`` columns.

    Other Parameters
    ----------------
    kwargs :
        Extra keyword arguments for `ptemcee.Sampler`.
        *Tip:* Consider setting the `pool` or `vectorize` keyword arguments in
        order to speed up likelihood evaluations.

    Notes
    -----
    The autocorrelation length, which has a complexity of :math:`O(N \log N)`
    in the number of samples, is recalculated at geometrically progressing
    intervals so that its amortized complexity per sample is constant. (In
    simpler terms, as the chains grow longer and the autocorrelation length
    takes longer to compute, we update it less frequently so that it is never
    more expensive than sampling the chain in the first place.)

    Examples
    --------

    >>> from ligo.skymap.bayestar.ez_emcee import ez_emcee
    >>> from matplotlib import pyplot as plt
    >>> import numpy as np
    >>>
    >>> def log_prob(params):
    ...     """Eggbox function"""
    ...     return 5 * np.log((2 + np.cos(0.5 * params).prod(-1)))
    ...
    >>> lo = [-3*np.pi, -3*np.pi]
    >>> hi = [+3*np.pi, +3*np.pi]
    >>> chain = ez_emcee(log_prob, lo, hi, vectorize=True)  # doctest: +SKIP
    Sampling: 51%|██  | 8628/16820 [00:04<00:04, 1966.74it/s, accept=0.535, acl=62]
    >>> plt.plot(chain[:, 0], chain[:, 1], '.')  # doctest: +SKIP

    .. image:: eggbox.png

    '''  # noqa: E501
    lo = np.asarray(lo)
    hi = np.asarray(hi)
    ndim = len(lo)
    if nwalkers is None:
        nwalkers = 4 * ndim

    nsteps = 64

    with tqdm(total=nburnin + nindep * nsteps) as progress:

        sampler = Sampler(nwalkers, ndim, log_prob_fn, logp, ntemps=ntemps,
                          loglargs=args, loglkwargs=kwargs, logpargs=[lo, hi],
                          random=np.random, **options)
        pos = np.random.uniform(lo, hi, (ntemps, nwalkers, ndim))

        # Burn in
        progress.set_description('Burning in')
        for pos, _, _ in sampler.sample(
                pos, iterations=nburnin, storechain=False):
            progress.update()

        sampler.reset()
        acl = np.nan
        while not np.isfinite(acl) or sampler.time < nindep * acl:

            # Advance the chain
            progress.total = nburnin + max(sampler.time + nsteps,
                                           nindep * acl)
            progress.set_description('Sampling')
            for pos, _, _ in sampler.sample(pos, iterations=nsteps):
                progress.update()

            # Refresh convergence statistics
            progress.set_description('Checking')
            acl = sampler.get_autocorr_time()[0].max()
            if np.isfinite(acl):
                acl = max(1, int(np.ceil(acl)))
            accept = np.mean(sampler.acceptance_fraction[0])
            progress.set_postfix(acl=acl, accept=accept)

            # The autocorrelation time calculation has complexity N log N in
            # the number of posterior samples. Only refresh the
            # autocorrelation length estimate on logarithmically spaced
            # samples so that the amortized complexity per sample is
            # constant.
            nsteps *= 2

    chain = sampler.chain[0, :, ::acl, :]
    s = chain.shape
    return chain.reshape((-1, s[-1]))
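

# The block below is not part of the original module: it is a minimal,
# self-contained usage sketch that mirrors the eggbox example from the
# docstring above, so the module can be run directly as a quick smoke test.
# It assumes, as the docstring example does, that ``vectorize=True`` is
# accepted by `ptemcee.Sampler` (it is forwarded via ``**options``).
if __name__ == '__main__':

    def eggbox_log_prob(params):
        """Eggbox test density from the docstring example."""
        return 5 * np.log(2 + np.cos(0.5 * params).prod(-1))

    lo = [-3 * np.pi, -3 * np.pi]
    hi = [+3 * np.pi, +3 * np.pi]
    chain = ez_emcee(eggbox_log_prob, lo, hi, vectorize=True)
    print('Drew {} posterior samples in {} dimensions'.format(*chain.shape))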