Coverage for bilby/core/sampler/nessai.py: 52%

153 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import os 

2import sys 

3 

4import numpy as np 

5from pandas import DataFrame 

6from scipy.special import logsumexp 

7 

8from ..utils import check_directory_exists_and_if_not_mkdir, load_json, logger 

9from .base_sampler import NestedSampler, signal_wrapper 

10 

11 

class Nessai(NestedSampler):
    """bilby wrapper of nessai (https://github.com/mj-will/nessai)

    All positional and keyword arguments passed to `run_sampler` are propagated
    to `nessai.flowsampler.FlowSampler`: those listed in
    :code:`run_kwargs_list` are passed to :code:`FlowSampler.run`, the rest are
    passed to the :code:`FlowSampler` constructor.

    See the documentation for an explanation of the different kwargs.

    Documentation: https://nessai.readthedocs.io/
    """

    sampler_name = "nessai"
    # Lazily-populated caches; nessai is only imported when they are first read.
    _default_kwargs = None
    _run_kwargs_list = None
    sampling_seed_key = "seed"

    @property
    def run_kwargs_list(self):
        """List of kwargs used in the run method of :code:`FlowSampler`"""
        if self._run_kwargs_list is None:
            from nessai.utils.bilbyutils import get_run_kwargs_list

            # "save" cannot be set through bilby (it is also dropped from the
            # default kwargs), so it is never forwarded to FlowSampler.run.
            ignored_kwargs = {"save"}
            # Build a new list instead of calling ``remove`` on the list nessai
            # returned, so nessai's own object is never mutated in place.
            self._run_kwargs_list = [
                k for k in get_run_kwargs_list() if k not in ignored_kwargs
            ]
        return self._run_kwargs_list

    @property
    def default_kwargs(self):
        """Default kwargs for nessai.

        Retrieves default values from nessai directly and then includes any
        bilby specific defaults. This avoids the need to update bilby when the
        defaults change or new kwargs are added to nessai.

        Includes the following kwargs that are specific to bilby:

        - :code:`nessai_log_level`: allows setting the logging level in nessai
        - :code:`nessai_logging_stream`: allows setting the logging stream
        - :code:`nessai_plot`: allows toggling the plotting in FlowSampler.run
        """
        if self._default_kwargs is None:
            from nessai.utils.bilbyutils import get_all_kwargs

            # Shallow copy so the dict returned by nessai is not modified.
            kwargs = dict(get_all_kwargs())

            # Defaults for bilby that will override nessai defaults
            bilby_defaults = dict(
                output=None,
                exit_code=self.exit_code,
                nessai_log_level=None,
                nessai_logging_stream="stdout",
                nessai_plot=True,
                plot_posterior=False,  # bilby already produces a posterior plot
                log_on_iteration=False,  # Use periodic logging by default
                logging_interval=60,  # Log every 60 seconds
            )
            kwargs.update(bilby_defaults)

            # Kwargs that cannot be set in bilby. ``signal_handling`` is passed
            # explicitly in ``run_sampler``, so leaving it here would duplicate
            # the keyword argument.
            for k in ("save", "signal_handling"):
                kwargs.pop(k, None)
            self._default_kwargs = kwargs
        return self._default_kwargs

    def log_prior(self, theta):
        """Compute the joint log-prior probability of a point.

        Parameters
        ----------
        theta: dict
            Mapping from each search parameter name to its value(s) in the
            prior (not unit-hypercube) space, as built by the nessai model
            wrapper returned by :code:`get_nessai_model`.

        Returns
        -------
        float: Joint ln prior probability of theta
        """
        return self.priors.ln_prob(theta, axis=0)

    def get_nessai_model(self):
        """Get the model for nessai.

        Returns
        -------
        nessai.model.Model
            A model whose likelihood, prior and unit-hypercube mappings
            delegate to this bilby sampler and its priors.
        """
        from nessai.livepoint import dict_to_live_points
        from nessai.model import Model as BaseModel

        # Explicit closure name: the static methods below call back into the
        # bilby sampler, which is easy to confuse with the Model's own ``self``.
        sampler = self

        class Model(BaseModel):
            """A wrapper class to pass our log_likelihood and priors into nessai

            Parameters
            ----------
            names : list of str
                List of parameters to sample
            priors : :obj:`bilby.core.prior.PriorDict`
                Priors to use for sampling. Needed for the bounds and the
                `sample` method.
            """

            def __init__(self, names, priors):
                self.names = names
                self.priors = priors
                self._update_bounds()

            @staticmethod
            def log_likelihood(x, **kwargs):
                """Compute the log likelihood"""
                # ``.item()`` converts each single-point field to a scalar.
                theta = [x[n].item() for n in sampler.search_parameter_keys]
                return sampler.log_likelihood(theta)

            @staticmethod
            def log_prior(x, **kwargs):
                """Compute the log prior"""
                theta = {n: x[n] for n in sampler.search_parameter_keys}
                return sampler.log_prior(theta)

            def _update_bounds(self):
                """Set the sampling bounds from each prior's minimum/maximum."""
                self.bounds = {
                    key: [self.priors[key].minimum, self.priors[key].maximum]
                    for key in self.names
                }

            def new_point(self, N=1):
                """Draw a point from the prior"""
                prior_samples = self.priors.sample(size=N)
                samples = {n: prior_samples[n] for n in self.names}
                return dict_to_live_points(samples)

            def new_point_log_prob(self, x):
                """Proposal probability for the new point"""
                return self.log_prior(x)

            @staticmethod
            def from_unit_hypercube(x):
                """Map samples from the unit hypercube to the prior."""
                theta = {
                    n: sampler.priors[n].rescale(x[n])
                    for n in sampler.search_parameter_keys
                }
                return dict_to_live_points(theta)

            @staticmethod
            def to_unit_hypercube(x):
                """Map samples from the prior to the unit hypercube."""
                theta = {n: x[n] for n in sampler.search_parameter_keys}
                return dict_to_live_points(sampler.priors.cdf(theta))

        model = Model(self.search_parameter_keys, self.priors)
        return model

    def split_kwargs(self):
        """Split kwargs into configuration and run time kwargs.

        Returns
        -------
        tuple of dict
            ``(kwargs, run_kwargs)``: kwargs for the :code:`FlowSampler`
            constructor and kwargs for :code:`FlowSampler.run`. Neither dict
            aliases :code:`self.kwargs`.
        """
        kwargs = self.kwargs.copy()
        run_kwargs = {}
        for k in self.run_kwargs_list:
            run_kwargs[k] = kwargs.pop(k)
        # bilby exposes FlowSampler.run's ``plot`` argument as ``nessai_plot``.
        run_kwargs["plot"] = kwargs.pop("nessai_plot")
        return kwargs, run_kwargs

    def get_posterior_weights(self):
        """Get the posterior weights for the nested samples.

        Returns
        -------
        numpy.ndarray
            Weights normalised to sum to one.
        """
        from nessai.posterior import compute_weights

        _, log_weights = compute_weights(
            np.array(self.fs.nested_samples["logL"]),
            np.array(self.fs.ns.state.nlive),
        )
        # Normalise in log space for numerical stability before exponentiating.
        w = np.exp(log_weights - logsumexp(log_weights))
        return w

    def get_nested_samples(self):
        """Get the nested samples dataframe, with columns renamed to bilby's names."""
        return DataFrame(self.fs.nested_samples).rename(
            columns=dict(logL="log_likelihood", logP="log_prior", it="iteration"),
        )

    def update_result(self):
        """Update the result object from the finished :code:`FlowSampler`."""
        from nessai.livepoint import live_points_to_array

        # Manually set likelihood evaluations because parallelisation breaks the counter
        self.result.num_likelihood_evaluations = self.fs.ns.total_likelihood_evaluations

        self.result.sampling_time = self.fs.ns.sampling_time
        self.result.samples = live_points_to_array(
            self.fs.posterior_samples, self.search_parameter_keys
        )
        self.result.log_likelihood_evaluations = self.fs.posterior_samples["logL"]
        self.result.nested_samples = self.get_nested_samples()
        self.result.nested_samples["weights"] = self.get_posterior_weights()
        self.result.log_evidence = self.fs.log_evidence
        self.result.log_evidence_err = self.fs.log_evidence_error

    @signal_wrapper
    def run_sampler(self):
        """Run the sampler.

        Nessai is designed to be run in two stages: initialise the sampler
        and then call the run method with additional configuration. This means
        there are effectively two sets of keyword arguments: one for
        initializing the sampler and the other for the run function.
        """
        from nessai.flowsampler import FlowSampler
        from nessai.utils import setup_logger

        kwargs, run_kwargs = self.split_kwargs()

        # Setup the logger for nessai, use nessai_log_level if specified, else use
        # the level of the bilby logger.
        nessai_log_level = kwargs.pop("nessai_log_level")
        if nessai_log_level is None or nessai_log_level == "bilby":
            nessai_log_level = logger.getEffectiveLevel()
        nessai_logging_stream = kwargs.pop("nessai_logging_stream")

        setup_logger(
            self.outdir,
            label=self.label,
            log_level=nessai_log_level,
            stream=nessai_logging_stream,
        )

        # Get the nessai model
        model = self.get_nessai_model()

        # Configure the sampler
        self.fs = FlowSampler(
            model,
            signal_handling=False,  # Disable signal handling so it can be handled by bilby
            **kwargs,
        )
        # Run the sampler
        self.fs.run(**run_kwargs)

        # Update the result
        self.update_result()

        return self.result

    def _translate_kwargs(self, kwargs):
        """Translate equivalent keyword arguments to nessai's names (in place).

        Maps any of :code:`npoints_equiv_kwargs` to ``nlive`` and any of
        :code:`npool_equiv_kwargs` to ``n_pool``, falling back to
        :code:`self._npool` when no pool size was given.
        """
        super()._translate_kwargs(kwargs)
        if "nlive" not in kwargs:
            for equiv in self.npoints_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["nlive"] = kwargs.pop(equiv)
        if "n_pool" not in kwargs:
            for equiv in self.npool_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["n_pool"] = kwargs.pop(equiv)
        if "n_pool" not in kwargs:
            kwargs["n_pool"] = self._npool

    def _verify_kwargs_against_default_kwargs(self):
        """Verify the keyword arguments.

        Merges an optional JSON ``config_file`` into the kwargs, fills in
        ``plot`` and ``output`` when unset, and creates the output directory.
        """
        if "config_file" in self.kwargs:
            d = load_json(self.kwargs["config_file"], None)
            self.kwargs.update(d)
            self.kwargs.pop("config_file")

        if not self.kwargs["plot"]:
            self.kwargs["plot"] = self.plot

        if not self.kwargs["output"]:
            # Trailing "" makes the path end with a separator.
            self.kwargs["output"] = os.path.join(
                self.outdir, f"{self.label}_nessai", ""
            )

        check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
        NestedSampler._verify_kwargs_against_default_kwargs(self)

    def write_current_state(self):
        """Write the current state of the sampler"""
        self.fs.ns.checkpoint()

    def write_current_state_and_exit(self, signum=None, frame=None):
        """
        Overwrites the base class to make sure that :code:`Nessai` terminates
        properly.
        """
        if hasattr(self, "fs"):
            self.fs.terminate_run(code=signum)
        else:
            logger.warning("Sampler is not initialized")
        self._log_interruption(signum=signum)
        sys.exit(self.exit_code)

    @classmethod
    def get_expected_outputs(cls, outdir=None, label=None):
        """Get lists of the expected outputs directories and files.

        These are used by :code:`bilby_pipe` when transferring files via HTCondor.

        Parameters
        ----------
        outdir : str
            The output directory.
        label : str
            The label for the run.

        Returns
        -------
        list
            List of file names. This will be empty for nessai.
        list
            List of directory names.
        """
        dirs = [os.path.join(outdir, f"{label}_{cls.sampler_name}", "")]
        dirs += [os.path.join(dirs[0], d, "") for d in ["proposal", "diagnostics"]]
        filenames = []
        return filenames, dirs

    def _setup_pool(self):
        """Intentionally a no-op.

        Presumably nessai manages its own pool, since ``n_pool`` is forwarded
        to it via the sampler kwargs (see ``_translate_kwargs``).
        """
        pass