Coverage for bilby/core/sampler/nessai.py: 52%
153 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2025-05-06 04:57 +0000
1import os
2import sys
4import numpy as np
5from pandas import DataFrame
6from scipy.special import logsumexp
8from ..utils import check_directory_exists_and_if_not_mkdir, load_json, logger
9from .base_sampler import NestedSampler, signal_wrapper
class Nessai(NestedSampler):
    """bilby wrapper of nessai (https://github.com/mj-will/nessai)

    All positional and keyword arguments passed to `run_sampler` are propagated
    to `nessai.flowsampler.FlowSampler`: keyword arguments listed in
    :code:`run_kwargs_list` are passed to `FlowSampler.run`, the remainder are
    passed to the `FlowSampler` constructor.

    See the documentation for an explanation of the different kwargs.

    Documentation: https://nessai.readthedocs.io/
    """

    sampler_name = "nessai"
    # Lazily populated and cached by the `default_kwargs` property.
    _default_kwargs = None
    # Lazily populated and cached by the `run_kwargs_list` property.
    _run_kwargs_list = None
    sampling_seed_key = "seed"

    @property
    def run_kwargs_list(self):
        """List of kwargs used in the run method of :code:`FlowSampler`

        The list is obtained from nessai on first access and cached on the
        instance. Kwargs that cannot be set through bilby (currently
        :code:`save`) are removed from it.
        """
        if not self._run_kwargs_list:
            from nessai.utils.bilbyutils import get_run_kwargs_list

            self._run_kwargs_list = get_run_kwargs_list()
            ignored_kwargs = ["save"]
            for ik in ignored_kwargs:
                if ik in self._run_kwargs_list:
                    self._run_kwargs_list.remove(ik)
        return self._run_kwargs_list

    @property
    def default_kwargs(self):
        """Default kwargs for nessai.

        Retrieves default values from nessai directly and then includes any
        bilby specific defaults. This avoids the need to update bilby when the
        defaults change or new kwargs are added to nessai. The result is
        computed on first access and cached on the instance.

        Includes the following kwargs that are specific to bilby:

        - :code:`nessai_log_level`: allows setting the logging level in nessai
        - :code:`nessai_logging_stream`: allows setting the logging stream
        - :code:`nessai_plot`: allows toggling the plotting in FlowSampler.run

        The nessai kwargs :code:`save` and :code:`signal_handling` are removed
        because they cannot be set through bilby (signal handling is always
        disabled in :code:`run_sampler` so that bilby can handle signals).
        """
        if not self._default_kwargs:
            from nessai.utils.bilbyutils import get_all_kwargs

            kwargs = get_all_kwargs()

            # Defaults for bilby that will override nessai defaults
            bilby_defaults = dict(
                output=None,
                exit_code=self.exit_code,
                nessai_log_level=None,
                nessai_logging_stream="stdout",
                nessai_plot=True,
                plot_posterior=False,  # bilby already produces a posterior plot
                log_on_iteration=False,  # Use periodic logging by default
                logging_interval=60,  # Log every 60 seconds
            )
            kwargs.update(bilby_defaults)
            # Kwargs that cannot be set in bilby
            remove = [
                "save",
                "signal_handling",
            ]
            for k in remove:
                if k in kwargs:
                    kwargs.pop(k)
            self._default_kwargs = kwargs
        return self._default_kwargs

    def log_prior(self, theta):
        """Evaluate the joint log-prior for a set of parameter values.

        Parameters
        ----------
        theta: dict
            Mapping from each search parameter name to its value(s) in the
            physical (prior) space, i.e. not on the unit hypercube. Values may
            be arrays, in which case the prior is evaluated element-wise.

        Returns
        -------
        float or numpy.ndarray: Joint ln prior probability of theta
        """
        return self.priors.ln_prob(theta, axis=0)

    def get_nessai_model(self):
        """Get the model for nessai.

        Returns
        -------
        nessai.model.Model
            A model whose likelihood, prior and hypercube mappings delegate to
            this sampler's likelihood and priors.
        """
        from nessai.livepoint import dict_to_live_points
        from nessai.model import Model as BaseModel

        # The static methods below intentionally close over the bilby sampler
        # (`self` of `get_nessai_model`), not over the Model instance.
        class Model(BaseModel):
            """A wrapper class to pass our log_likelihood and priors into nessai

            Parameters
            ----------
            names : list of str
                List of parameters to sample
            priors : :obj:`bilby.core.prior.PriorDict`
                Priors to use for sampling. Needed for the bounds and the
                `sample` method.
            """

            def __init__(self, names, priors):
                self.names = names
                self.priors = priors
                self._update_bounds()

            @staticmethod
            def log_likelihood(x, **kwargs):
                """Compute the log likelihood"""
                # `.item()` converts each field of a single live point to a
                # Python scalar before calling the bilby likelihood.
                theta = [x[n].item() for n in self.search_parameter_keys]
                return self.log_likelihood(theta)

            @staticmethod
            def log_prior(x, **kwargs):
                """Compute the log prior"""
                theta = {n: x[n] for n in self.search_parameter_keys}
                return self.log_prior(theta)

            def _update_bounds(self):
                self.bounds = {
                    key: [self.priors[key].minimum, self.priors[key].maximum]
                    for key in self.names
                }

            def new_point(self, N=1):
                """Draw a point from the prior"""
                prior_samples = self.priors.sample(size=N)
                samples = {n: prior_samples[n] for n in self.names}
                return dict_to_live_points(samples)

            def new_point_log_prob(self, x):
                """Proposal probability for the new point"""
                return self.log_prior(x)

            @staticmethod
            def from_unit_hypercube(x):
                """Map samples from the unit hypercube to the prior."""
                theta = {}
                for n in self.search_parameter_keys:
                    theta[n] = self.priors[n].rescale(x[n])
                return dict_to_live_points(theta)

            @staticmethod
            def to_unit_hypercube(x):
                """Map samples from the prior to the unit hypercube."""
                theta = {n: x[n] for n in self.search_parameter_keys}
                return dict_to_live_points(self.priors.cdf(theta))

        model = Model(self.search_parameter_keys, self.priors)
        return model

    def split_kwargs(self):
        """Split kwargs into configuration and run time kwargs

        Returns
        -------
        tuple of (dict, dict)
            The kwargs for the :code:`FlowSampler` constructor and the kwargs
            for :code:`FlowSampler.run`. The bilby-specific
            :code:`nessai_plot` kwarg is passed to the run method as
            :code:`plot`.
        """
        kwargs = self.kwargs.copy()
        run_kwargs = {}
        for k in self.run_kwargs_list:
            run_kwargs[k] = kwargs.pop(k)
        run_kwargs["plot"] = kwargs.pop("nessai_plot")
        return kwargs, run_kwargs

    def get_posterior_weights(self):
        """Get the posterior weights for the nested samples

        Returns
        -------
        numpy.ndarray
            Weights for each nested sample, normalised to sum to one.
        """
        from nessai.posterior import compute_weights

        _, log_weights = compute_weights(
            np.array(self.fs.nested_samples["logL"]),
            np.array(self.fs.ns.state.nlive),
        )
        # Normalise in log space before exponentiating to avoid overflow.
        w = np.exp(log_weights - logsumexp(log_weights))
        return w

    def get_nested_samples(self):
        """Get the nested samples dataframe

        nessai's column names are renamed to the bilby conventions
        (:code:`logL` -> :code:`log_likelihood`, :code:`logP` ->
        :code:`log_prior`, :code:`it` -> :code:`iteration`).
        """
        ns = DataFrame(self.fs.nested_samples)
        ns.rename(
            columns=dict(logL="log_likelihood", logP="log_prior", it="iteration"),
            inplace=True,
        )
        return ns

    def update_result(self):
        """Update the result object.

        Copies the likelihood-evaluation count, sampling time, posterior
        samples, nested samples (with posterior weights) and evidence
        estimates from the nessai sampler onto :code:`self.result`.
        """
        from nessai.livepoint import live_points_to_array

        # Manually set likelihood evaluations because parallelisation breaks the counter
        self.result.num_likelihood_evaluations = self.fs.ns.total_likelihood_evaluations

        self.result.sampling_time = self.fs.ns.sampling_time
        self.result.samples = live_points_to_array(
            self.fs.posterior_samples, self.search_parameter_keys
        )
        self.result.log_likelihood_evaluations = self.fs.posterior_samples["logL"]
        self.result.nested_samples = self.get_nested_samples()
        self.result.nested_samples["weights"] = self.get_posterior_weights()
        self.result.log_evidence = self.fs.log_evidence
        self.result.log_evidence_err = self.fs.log_evidence_error

    @signal_wrapper
    def run_sampler(self):
        """Run the sampler.

        Nessai is designed to be run in two stages: initialise the sampler
        and then call the run method with additional configuration. This means
        there are effectively two sets of keyword arguments: one for
        initializing the sampler and the other for the run function.

        Returns
        -------
        The updated :code:`self.result`.
        """
        from nessai.flowsampler import FlowSampler
        from nessai.utils import setup_logger

        kwargs, run_kwargs = self.split_kwargs()

        # Setup the logger for nessai, use nessai_log_level if specified, else use
        # the level of the bilby logger.
        nessai_log_level = kwargs.pop("nessai_log_level")
        if nessai_log_level is None or nessai_log_level == "bilby":
            nessai_log_level = logger.getEffectiveLevel()
        nessai_logging_stream = kwargs.pop("nessai_logging_stream")

        setup_logger(
            self.outdir,
            label=self.label,
            log_level=nessai_log_level,
            stream=nessai_logging_stream,
        )

        # Get the nessai model
        model = self.get_nessai_model()

        # Configure the sampler
        self.fs = FlowSampler(
            model,
            signal_handling=False,  # Disable signal handling so it can be handled by bilby
            **kwargs,
        )
        # Run the sampler
        self.fs.run(**run_kwargs)

        # Update the result
        self.update_result()

        return self.result

    def _translate_kwargs(self, kwargs):
        """Translate the keyword arguments

        Maps bilby's equivalent names for the number of live points onto
        :code:`nlive` and for the pool size onto :code:`n_pool`. If no pool
        size is given, :code:`n_pool` defaults to :code:`self._npool`.
        """
        super()._translate_kwargs(kwargs)
        if "nlive" not in kwargs:
            for equiv in self.npoints_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["nlive"] = kwargs.pop(equiv)
        if "n_pool" not in kwargs:
            for equiv in self.npool_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["n_pool"] = kwargs.pop(equiv)
        if "n_pool" not in kwargs:
            kwargs["n_pool"] = self._npool

    def _verify_kwargs_against_default_kwargs(self):
        """Verify the keyword arguments

        Merges any kwargs from a JSON :code:`config_file` into
        :code:`self.kwargs`, falls back to bilby's :code:`plot` setting and a
        default output directory (:code:`<outdir>/<label>_nessai/`), creates
        that directory, and then runs the standard bilby verification.
        """
        if "config_file" in self.kwargs:
            d = load_json(self.kwargs["config_file"], None)
            self.kwargs.update(d)
            self.kwargs.pop("config_file")

        if not self.kwargs["plot"]:
            self.kwargs["plot"] = self.plot

        if not self.kwargs["output"]:
            self.kwargs["output"] = os.path.join(
                self.outdir, f"{self.label}_nessai", ""
            )

        check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
        NestedSampler._verify_kwargs_against_default_kwargs(self)

    def write_current_state(self):
        """Write the current state of the sampler

        Delegates to nessai's checkpointing. If the nessai sampler has not
        been created yet (:code:`run_sampler` has not been called), a warning
        is logged instead of raising an :code:`AttributeError`, matching
        :code:`write_current_state_and_exit`.
        """
        if hasattr(self, "fs"):
            self.fs.ns.checkpoint()
        else:
            logger.warning("Sampler is not initialized")

    def write_current_state_and_exit(self, signum=None, frame=None):
        """
        Overwrites the base class to make sure that :code:`Nessai` terminates
        properly.

        Parameters
        ----------
        signum : int, optional
            The signal number, forwarded to nessai as the termination code.
        frame : optional
            The current stack frame (unused; part of the signal-handler
            signature).
        """
        if hasattr(self, "fs"):
            self.fs.terminate_run(code=signum)
        else:
            logger.warning("Sampler is not initialized")
        self._log_interruption(signum=signum)
        sys.exit(self.exit_code)

    @classmethod
    def get_expected_outputs(cls, outdir=None, label=None):
        """Get lists of the expected outputs directories and files.

        These are used by :code:`bilby_pipe` when transferring files via HTCondor.

        Parameters
        ----------
        outdir : str
            The output directory. Must be provided despite the :code:`None`
            default, since it is passed to :code:`os.path.join`.
        label : str
            The label for the run.

        Returns
        -------
        list
            List of file names. This will be empty for nessai.
        list
            List of directory names: the nessai output directory and its
            :code:`proposal` and :code:`diagnostics` subdirectories.
        """
        dirs = [os.path.join(outdir, f"{label}_{cls.sampler_name}", "")]
        dirs += [os.path.join(dirs[0], d, "") for d in ["proposal", "diagnostics"]]
        filenames = []
        return filenames, dirs

    def _setup_pool(self):
        """No-op: bilby does not create a pool for nessai.

        The requested pool size is instead forwarded to nessai through the
        :code:`n_pool` kwarg (see :code:`_translate_kwargs`); presumably
        nessai manages its own pool from that value.
        """
        pass