Coverage for bilby/core/sampler/ultranest.py: 45%
117 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import datetime
2import inspect
3import time
5import numpy as np
6from pandas import DataFrame
8from ..utils import logger
9from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
class Ultranest(_TemporaryFileSamplerMixin, NestedSampler):
    """
    bilby wrapper of ultranest
    (https://johannesbuchner.github.io/UltraNest/index.html)

    All positional and keyword arguments (i.e., the args and kwargs) passed to
    `run_sampler` will be propagated to `ultranest.ReactiveNestedSampler.run`
    or `ultranest.NestedSampler.run`, see documentation for those classes for
    further help. Under Parameters, we list commonly used kwargs and the
    bilby defaults. If the number of live points is specified the
    `ultranest.NestedSampler` will be used, otherwise the
    `ultranest.ReactiveNestedSampler` will be used.

    Parameters
    ==========
    num_live_points: int
        The number of live points, note this can also equivalently be given as
        one of [nlive, nlives, n_live_points, num_live_points]. If not given
        then the `ultranest.ReactiveNestedSampler` will be used, which does not
        require the number of live points to be specified.
    show_status: Bool
        If true, print information about the convergence during sampling.
        Can equivalently be given as `verbose`.
    resume: bool
        If true, resume run from checkpoint (if available); if false, any
        existing output is overwritten.
    step_sampler:
        An UltraNest step sampler object. This defaults to None, so the default
        stepping behaviour is used.
    callback_interval: int
        Only used when a temporary directory is in use: the number of
        ultranest `viz_callback` invocations between the periodic updates
        performed from that callback (e.g. saving the elapsed sampling time).
    """

    sampler_name = "ultranest"
    abbreviation = "ultra"
    default_kwargs = dict(
        resume=True,
        show_status=True,
        num_live_points=None,
        wrapped_params=None,
        log_dir=None,
        derived_param_names=[],
        run_num=None,
        vectorized=False,
        num_test_samples=2,
        draw_multiple=True,
        num_bootstraps=30,
        update_interval_iter=None,
        update_interval_ncall=None,
        log_interval=None,
        dlogz=None,
        max_iters=None,
        update_interval_volume_fraction=0.2,
        viz_callback=None,
        dKL=0.5,
        frac_remain=0.01,
        Lepsilon=0.001,
        min_ess=400,
        max_ncalls=None,
        max_num_improvement_loops=-1,
        min_num_live_points=400,
        cluster_num_live_points=40,
        step_sampler=None,
    )

    short_name = "ultra"

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        exit_code=77,
        skip_import_verification=False,
        temporary_directory=True,
        callback_interval=10,
        **kwargs,
    ):
        super().__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            exit_code=exit_code,
            temporary_directory=temporary_directory,
            **kwargs,
        )
        self._apply_ultranest_boundaries()

        if self.use_temporary_directory:
            # set callback interval, so copying of results does not thrash the
            # disk (ultranest will call viz_callback quite a lot)
            self.callback_interval = callback_interval

    def _translate_kwargs(self, kwargs):
        """Map bilby-style keyword arguments onto the names ultranest expects.

        - any keyword in ``npoints_equiv_kwargs`` becomes ``num_live_points``
          (only if ``num_live_points`` was not given directly);
        - ``verbose`` becomes ``show_status`` (unless that is already given);
        - a boolean ``resume`` becomes an ultranest resume mode:
          ``True`` -> ``"resume"``, ``False`` -> ``"overwrite"``.

        Returns the translated ``kwargs``.
        """
        kwargs = super()._translate_kwargs(kwargs)
        if "num_live_points" not in kwargs:
            for equiv in self.npoints_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["num_live_points"] = kwargs.pop(equiv)
        if "verbose" in kwargs and "show_status" not in kwargs:
            kwargs["show_status"] = kwargs.pop("verbose")
        # NOTE(review): an absent "resume" is treated as False ("overwrite");
        # depending on how the base class merges default_kwargs this may take
        # precedence over the resume=True default -- confirm this is intended.
        resume = kwargs.get("resume", False)
        if resume is True:
            kwargs["resume"] = "resume"
        elif resume is False:
            kwargs["resume"] = "overwrite"
        return kwargs

    def _verify_kwargs_against_default_kwargs(self):
        """Check the kwargs.

        Moves ``log_dir`` out of the kwargs into ``self.outputfiles_basename``
        and installs this class's ``_viz_callback`` when the user supplied no
        ``viz_callback``, before delegating to the NestedSampler check.
        """
        self.outputfiles_basename = self.kwargs.pop("log_dir", None)
        if self.kwargs["viz_callback"] is None:
            self.kwargs["viz_callback"] = self._viz_callback

        NestedSampler._verify_kwargs_against_default_kwargs(self)

    def _viz_callback(self, *args, **kwargs):
        """Callback handed to ultranest as ``viz_callback``.

        Arguments from ultranest are ignored. When a temporary directory is
        used, every ``callback_interval``-th call attempts to copy the
        temporary outputs back and saves the elapsed sampling time.
        """
        if not self.use_temporary_directory:
            return
        if self._viz_callback_counter % self.callback_interval == 0:
            # NOTE: the copy below is skipped by the caller-name guard in
            # _copy_temporary_directory_contents_to_proper_path, so in practice
            # only the sampling-time save has an effect here.
            self._copy_temporary_directory_contents_to_proper_path()
            self._calculate_and_save_sampling_time()
        self._viz_callback_counter += 1

    def _apply_ultranest_boundaries(self):
        """Build the ultranest ``wrapped_params`` mask from the priors.

        Only applied when the user did not supply ``wrapped_params`` (None or
        empty). The mask follows the order of ``search_parameter_keys`` --
        the order in which parameter names are passed to the integrator -- with
        ``1`` for parameters whose prior has a periodic boundary, else ``0``.
        """
        wrapped = self.kwargs.get("wrapped_params")
        if wrapped is None or len(wrapped) == 0:
            self.kwargs["wrapped_params"] = [
                1 if self.priors[key].boundary == "periodic" else 0
                for key in self.search_parameter_keys
            ]

    def _copy_temporary_directory_contents_to_proper_path(self):
        """
        Copy the temporary back to the proper path.
        Do not delete the temporary directory.

        The copy is skipped when the immediate caller is ``_viz_callback``.
        """
        # Only the caller's function name is needed; context=0 avoids reading
        # source lines for every frame on the stack.
        caller = inspect.stack(context=0)[1].function
        if caller != "_viz_callback":
            super()._copy_temporary_directory_contents_to_proper_path()

    @property
    def sampler_function_kwargs(self):
        """Keyword arguments forwarded to the integrator's ``run`` method.

        The forwarded keys depend on whether ``ultranest.NestedSampler`` (when
        ``num_live_points`` is set) or ``ultranest.ReactiveNestedSampler`` is
        used; keys absent from ``self.kwargs`` are omitted.
        """
        if self.kwargs.get("num_live_points", None) is not None:
            keys = (
                "update_interval_iter",
                "update_interval_ncall",
                "log_interval",
                "dlogz",
                "max_iters",
            )
        else:
            keys = (
                "update_interval_volume_fraction",
                "update_interval_ncall",
                "log_interval",
                "show_status",
                "viz_callback",
                "dlogz",
                "dKL",
                "frac_remain",
                "Lepsilon",
                "min_ess",
                "max_iters",
                "max_ncalls",
                "max_num_improvement_loops",
                "min_num_live_points",
                "cluster_num_live_points",
            )

        return {key: self.kwargs[key] for key in keys if key in self.kwargs}

    @property
    def sampler_init_kwargs(self):
        """Keyword arguments forwarded to the integrator's constructor.

        Common keys are always considered; ``num_live_points`` is added for
        ``ultranest.NestedSampler``, otherwise the ReactiveNestedSampler-only
        keys are added. Keys absent from ``self.kwargs`` are omitted.
        """
        keys = [
            "derived_param_names",
            "resume",
            "run_num",
            "vectorized",
            "log_dir",
            "wrapped_params",
        ]
        if self.kwargs.get("num_live_points", None) is not None:
            keys.append("num_live_points")
        else:
            keys.extend(["num_test_samples", "draw_multiple", "num_bootstraps"])

        return {key: self.kwargs[key] for key in keys if key in self.kwargs}

    @signal_wrapper
    def run_sampler(self):
        """Run ultranest and return the populated bilby result."""
        import ultranest
        import ultranest.stepsampler

        if "dlogz" in self.kwargs and self.kwargs["dlogz"] is None:
            # remove dlogz, so ultranest defaults (which are different for
            # NestedSampler and ReactiveNestedSampler) are used
            self.kwargs.pop("dlogz")

        self._verify_kwargs_against_default_kwargs()

        # the step sampler is attached to the sampler object rather than
        # passed to the integrator, so take it out of the kwargs
        stepsampler = self.kwargs.pop("step_sampler", None)

        self._setup_run_directory()
        self.kwargs["log_dir"] = self.kwargs["outputfiles_basename"]
        self._check_and_load_sampling_time_file()

        # use reactive nested sampler when no live points are given
        if self.kwargs.get("num_live_points", None) is not None:
            integrator = ultranest.integrator.NestedSampler
        else:
            integrator = ultranest.integrator.ReactiveNestedSampler

        sampler = integrator(
            self.search_parameter_keys,
            self.log_likelihood,
            transform=self.prior_transform,
            **self.sampler_init_kwargs,
        )

        if stepsampler is not None:
            if isinstance(stepsampler, ultranest.stepsampler.StepSampler):
                sampler.stepsampler = stepsampler
            else:
                logger.warning(
                    "The supplied step sampler is not the correct type. "
                    "The default step sampling will be used instead."
                )

        if self.use_temporary_directory:
            self._viz_callback_counter = 1

        self.start_time = time.time()
        results = sampler.run(**self.sampler_function_kwargs)
        self._calculate_and_save_sampling_time()

        self._clean_up_run_directory()

        self._generate_result(results)
        self.calc_likelihood_count()

        return self.result

    def _generate_result(self, out):
        """Populate ``self.result`` from the ultranest output dictionary.

        Posterior samples are drawn from the weighted samples by rejection:
        each sample is kept with probability ``weight / max(weight)``.
        """
        from ..utils.random import rng

        weighted = out["weighted_samples"]
        data = np.array(weighted["points"])
        weights = np.array(weighted["weights"])

        scaledweights = weights / weights.max()
        mask = rng.uniform(0, 1, len(scaledweights)) < scaledweights

        nested_samples = DataFrame(data, columns=self.search_parameter_keys)
        nested_samples["weights"] = weights
        nested_samples["log_likelihood"] = weighted["logl"]
        self.result.log_likelihood_evaluations = np.array(weighted["logl"])[mask]
        self.result.sampler_output = out
        self.result.samples = data[mask, :]
        self.result.nested_samples = nested_samples
        self.result.log_evidence = out["logz"]
        self.result.log_evidence_err = out["logzerr"]
        num_live_points = self.kwargs.get("num_live_points", None)
        if num_live_points is not None:
            # only defined for the fixed-live-point NestedSampler
            self.result.information_gain = (
                np.power(out["logzerr"], 2) * num_live_points
            )

        self.result.outputfiles_basename = self.outputfiles_basename
        self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)

    def log_likelihood(self, theta):
        """Return the bilby log-likelihood passed through ``np.nan_to_num``.

        NaN becomes 0.0 and +/-inf become large finite values, so ultranest
        always receives a finite number.
        """
        log_l = super().log_likelihood(theta=theta)
        return np.nan_to_num(log_l)