Coverage for bilby/core/sampler/dnest4.py: 60%
111 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import datetime
2import time
4import numpy as np
6from ..utils import logger
7from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
10class _DNest4Model(object):
11 def __init__(
12 self, log_likelihood_func, from_prior_func, widths, centers, highs, lows
13 ):
14 """Initialize the DNest4 model.
15 Args:
16 log_likelihood_func: function
17 The loglikelihood function to use during the Nested Sampling run.
18 from_prior_func: function
19 The function to use when randomly selecting parameter vectors from the prior space.
20 widths: array_like
21 The approximate widths of the prior distrbutions.
22 centers: array_like
23 The approximate center points of the prior distributions.
24 """
25 self._log_likelihood = log_likelihood_func
26 self._from_prior = from_prior_func
27 self._widths = widths
28 self._centers = centers
29 self._highs = highs
30 self._lows = lows
31 self._n_dim = len(widths)
32 return
34 def log_likelihood(self, coords):
35 """The model's log_likelihood function"""
36 return self._log_likelihood(coords)
38 def from_prior(self):
39 """The model's function to select random points from the prior space."""
40 return self._from_prior()
42 def perturb(self, coords):
43 """The perturb function to perform Monte Carlo trial moves."""
44 from ..utils.random import rng
46 idx = rng.integers(self._n_dim)
48 coords[idx] += self._widths[idx] * (rng.uniform(size=1) - 0.5)
49 cw = self._widths[idx]
50 cc = self._centers[idx]
52 coords[idx] = self.wrap(coords[idx], (cc - 0.5 * cw), cc + 0.5 * cw)
54 return 0.0
56 @staticmethod
57 def wrap(x, minimum, maximum):
58 if maximum <= minimum:
59 raise ValueError(
60 f"maximum {maximum} <= minimum {minimum}, when trying to wrap coordinates"
61 )
62 return (x - minimum) % (maximum - minimum) + minimum
class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
    """
    Bilby wrapper of DNest4

    Parameters
    ==========
    TBD

    Other Parameters
    ================
    num_particles: int
        The number of points to use in the Nested Sampling active population.
    max_num_levels: int
        The max number of diffusive likelihood levels that DNest4 should initialize
        during the Diffusive Nested Sampling run.
    backend: str
        The python DNest4 backend for storing the output.
        Options are: 'memory' and 'csv'. If 'memory' the
        DNest4 outputs are stored in memory during the run. If 'csv' the
        DNest4 outputs are written out to files with a CSV format during
        the run.
        CSV backend may not be functional right now (October 2020)
    num_steps: int
        The number of MCMC iterations to run
    new_level_interval: int
        The number of moves to run before creating a new diffusive likelihood level
    lam: float
        Set the backtracking scale length
    beta: float
        Set the strength of effect to force the histogram to equal bin counts
    seed: int
        Set the seed for the C++ random number generator
    verbose: Bool
        If True, prints information during run
    """

    sampler_name = "d4nest"
    default_kwargs = dict(
        max_num_levels=20,
        num_steps=500,
        new_level_interval=10000,
        num_per_step=10000,
        thread_steps=1,
        num_particles=1000,
        lam=10.0,
        beta=100,
        seed=None,
        verbose=True,
        outputfiles_basename=None,
        backend="memory",
    )
    short_name = "dn4"
    hard_exit = True
    sampling_seed_key = "seed"

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        exit_code=77,
        skip_import_verification=False,
        temporary_directory=True,
        **kwargs,
    ):
        super(DNest4, self).__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            temporary_directory=temporary_directory,
            exit_code=exit_code,
            **kwargs,
        )

        self.num_particles = self.kwargs["num_particles"]
        self.max_num_levels = self.kwargs["max_num_levels"]
        self._verbose = self.kwargs["verbose"]
        self._backend = self.kwargs["backend"]

        self.start_time = np.nan
        self.sampler = None
        self._information = np.nan

        # Empirically estimate each prior's support (width, center, bounds)
        # from a large prior sample; these drive _DNest4Model.perturb.
        widths = []
        centers = []
        highs = []
        lows = []

        samples = self.priors.sample(size=10000)

        for key in self.search_parameter_keys:
            pts = samples[key]
            low = pts.min()
            high = pts.max()
            width = high - low
            center = (high + low) / 2.0
            widths.append(width)
            centers.append(center)

            highs.append(high)
            lows.append(low)

        self._widths = np.array(widths)
        self._centers = np.array(centers)
        self._highs = np.array(highs)
        self._lows = np.array(lows)

        self._dnest4_model = _DNest4Model(
            self.log_likelihood,
            self.get_random_draw_from_prior,
            self._widths,
            self._centers,
            self._highs,
            self._lows,
        )

    def _set_backend(self):
        """Return the DNest4 backend instance selected by the ``backend`` kwarg."""
        import dnest4

        if self._backend == "csv":
            # NOTE(review): CSV backend may not be functional (see class docstring).
            return dnest4.backends.CSVBackend(
                f"{self.outdir}/dnest4{self.label}/", sep=" "
            )
        else:
            return dnest4.backends.MemoryBackend()

    def _set_dnest4_kwargs(self):
        """Collect the kwargs passed straight through to ``DNest4Sampler.sample``."""
        dnest4_keys = ["num_steps", "new_level_interval", "lam", "beta", "seed"]
        self.dnest4_kwargs = {key: self.kwargs[key] for key in dnest4_keys}

    @signal_wrapper
    def run_sampler(self):
        """Run the Diffusive Nested Sampling run and populate ``self.result``."""
        import dnest4

        self._set_dnest4_kwargs()
        backend = self._set_backend()

        self._verify_kwargs_against_default_kwargs()
        self._setup_run_directory()
        self._check_and_load_sampling_time_file()
        self.start_time = time.time()

        self.sampler = dnest4.DNest4Sampler(self._dnest4_model, backend=backend)
        out = self.sampler.sample(
            self.max_num_levels, num_particles=self.num_particles, **self.dnest4_kwargs
        )

        # Iterate the sample generator; optionally log log(Z) every 100 steps.
        for i, sample in enumerate(out):
            if self._verbose and ((i + 1) % 100 == 0):
                stats = self.sampler.postprocess()
                logger.info(f"Iteration: {i + 1} log(Z): {stats['log_Z']}")

        self._calculate_and_save_sampling_time()
        self._clean_up_run_directory()

        stats = self.sampler.postprocess(resample=1)
        self.result.log_evidence = stats["log_Z"]
        self._information = stats["H"]
        # Standard nested-sampling error estimate: sqrt(H / num_live_points).
        self.result.log_evidence_err = np.sqrt(self._information / self.num_particles)
        self.result.samples = np.array(self.sampler.backend.posterior_samples)

        self.result.sampler_output = out
        self.result.outputfiles_basename = self.outputfiles_basename
        self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)

        self.calc_likelihood_count()

        return self.result

    def _translate_kwargs(self, kwargs):
        """Map generic 'walks'-equivalent kwargs onto DNest4's ``num_steps``.

        BUGFIX: previously fell off the end (implicitly returning None);
        it now returns the kwargs dict, matching the base-class contract it
        itself relies on via ``kwargs = super()._translate_kwargs(kwargs)``.
        """
        kwargs = super()._translate_kwargs(kwargs)
        if "num_steps" not in kwargs:
            for equiv in self.walks_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["num_steps"] = kwargs.pop(equiv)
        return kwargs

    def _verify_kwargs_against_default_kwargs(self):
        """Pop ``outputfiles_basename`` (not a real DNest4 kwarg) before verifying."""
        self.outputfiles_basename = self.kwargs.pop("outputfiles_basename", None)
        super(DNest4, self)._verify_kwargs_against_default_kwargs()