Coverage for bilby/core/sampler/ultranest.py: 45%

117 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-05-06 04:57 +0000

1import datetime 

2import inspect 

3import time 

4 

5import numpy as np 

6from pandas import DataFrame 

7 

8from ..utils import logger 

9from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper 

10 

11 

class Ultranest(_TemporaryFileSamplerMixin, NestedSampler):
    """
    bilby wrapper of ultranest
    (https://johannesbuchner.github.io/UltraNest/index.html)

    Keyword arguments passed to `run_sampler` are split between the
    constructor and the `run` method of `ultranest.ReactiveNestedSampler` or
    `ultranest.NestedSampler` (see `sampler_init_kwargs` and
    `sampler_function_kwargs`); see the documentation for those classes for
    further help. Under Parameters, we list commonly used kwargs and the
    bilby defaults. If the number of live points is specified the
    `ultranest.NestedSampler` will be used, otherwise the
    `ultranest.ReactiveNestedSampler` will be used.

    Parameters
    ==========
    num_live_points: int
        The number of live points, note this can also equivalently be given as
        one of [nlive, nlives, n_live_points, num_live_points]. If not given
        then the `ultranest.ReactiveNestedSampler` will be used, which does not
        require the number of live points to be specified.
    show_status: bool
        If true, print information about the convergence during sampling.
        Can also be given as `verbose`.
    resume: bool
        If true, resume run from checkpoint (if available); if false, any
        existing output is overwritten.
    step_sampler:
        An `ultranest.stepsampler.StepSampler` instance. This defaults to
        None, so the default stepping behaviour is used. Objects of any other
        type are ignored with a warning.
    """

    sampler_name = "ultranest"
    abbreviation = "ultra"
    default_kwargs = dict(
        resume=True,
        show_status=True,
        num_live_points=None,
        wrapped_params=None,
        log_dir=None,
        derived_param_names=[],
        run_num=None,
        vectorized=False,
        num_test_samples=2,
        draw_multiple=True,
        num_bootstraps=30,
        update_interval_iter=None,
        update_interval_ncall=None,
        log_interval=None,
        dlogz=None,
        max_iters=None,
        update_interval_volume_fraction=0.2,
        viz_callback=None,
        dKL=0.5,
        frac_remain=0.01,
        Lepsilon=0.001,
        min_ess=400,
        max_ncalls=None,
        max_num_improvement_loops=-1,
        min_num_live_points=400,
        cluster_num_live_points=40,
        step_sampler=None,
    )

    short_name = "ultra"

    def __init__(
        self,
        likelihood,
        priors,
        outdir="outdir",
        label="label",
        use_ratio=False,
        plot=False,
        exit_code=77,
        skip_import_verification=False,
        temporary_directory=True,
        callback_interval=10,
        **kwargs,
    ):
        """Set up the sampler and derive periodic-boundary flags.

        Parameters
        ==========
        callback_interval: int
            Only used with a temporary directory: `_viz_callback` does its
            periodic work once every `callback_interval` invocations.

        All other arguments are forwarded unchanged to the parent class
        constructor.
        """
        super().__init__(
            likelihood=likelihood,
            priors=priors,
            outdir=outdir,
            label=label,
            use_ratio=use_ratio,
            plot=plot,
            skip_import_verification=skip_import_verification,
            exit_code=exit_code,
            temporary_directory=temporary_directory,
            **kwargs,
        )
        self._apply_ultranest_boundaries()

        if self.use_temporary_directory:
            # set callback interval, so copying of results does not thrash the
            # disk (ultranest will call viz_callback quite a lot)
            self.callback_interval = callback_interval

    def _translate_kwargs(self, kwargs):
        """Normalise user-supplied keyword arguments for ultranest.

        - Any alias in `npoints_equiv_kwargs` is renamed to `num_live_points`
          (unless `num_live_points` itself was given).
        - `verbose` is renamed to `show_status` (unless `show_status` was
          given).
        - A boolean `resume` is mapped to an ultranest resume mode:
          True -> "resume", False (or absent) -> "overwrite". Any other value,
          e.g. an explicit mode string, is passed through unchanged.

        Parameters
        ==========
        kwargs: dict
            The keyword arguments; modified in place.

        Returns
        =======
        dict
            The same, translated, `kwargs` dictionary.
        """
        kwargs = super()._translate_kwargs(kwargs)
        if "num_live_points" not in kwargs:
            for equiv in self.npoints_equiv_kwargs:
                if equiv in kwargs:
                    kwargs["num_live_points"] = kwargs.pop(equiv)
        if "verbose" in kwargs and "show_status" not in kwargs:
            kwargs["show_status"] = kwargs.pop("verbose")
        # NOTE(review): an absent `resume` is treated as False here, so the
        # `resume=True` entry in `default_kwargs` does not take effect for
        # translated kwargs -- confirm this is the intended default.
        resume = kwargs.get("resume", False)
        if resume is True:
            kwargs["resume"] = "resume"
        elif resume is False:
            kwargs["resume"] = "overwrite"
        return kwargs

    def _verify_kwargs_against_default_kwargs(self):
        """Check the kwargs.

        Moves `log_dir` out of `self.kwargs` into `self.outputfiles_basename`.
        If no `viz_callback` was supplied, installs `_viz_callback`. Then
        defers to `NestedSampler._verify_kwargs_against_default_kwargs`.
        """
        self.outputfiles_basename = self.kwargs.pop("log_dir", None)
        if self.kwargs["viz_callback"] is None:
            self.kwargs["viz_callback"] = self._viz_callback

        NestedSampler._verify_kwargs_against_default_kwargs(self)

    def _viz_callback(self, *args, **kwargs):
        """Progress callback handed to ultranest; all arguments are ignored.

        When a temporary directory is in use, every `callback_interval`-th
        invocation does two things: it calls the (caller-aware) copy method
        and it saves the elapsed sampling time.
        """
        if self.use_temporary_directory:
            if not (self._viz_callback_counter % self.callback_interval):
                self._copy_temporary_directory_contents_to_proper_path()
                self._calculate_and_save_sampling_time()
            self._viz_callback_counter += 1

    def _apply_ultranest_boundaries(self):
        """Fill in `wrapped_params` from the prior boundaries if it is unset.

        Produces one flag per search parameter, in `search_parameter_keys`
        order. This is the same order as the parameter names passed to the
        integrator in `run_sampler`. The flag is 1 for priors with a
        "periodic" boundary and 0 otherwise. A user-supplied, non-empty
        `wrapped_params` is left untouched.
        """
        wrapped = self.kwargs.get("wrapped_params")
        if wrapped is None or len(wrapped) == 0:
            # Build directly in search-key order so the flags are guaranteed
            # to line up with the parameter names given to ultranest.
            self.kwargs["wrapped_params"] = [
                1 if self.priors[key].boundary == "periodic" else 0
                for key in self.search_parameter_keys
            ]

    def _copy_temporary_directory_contents_to_proper_path(self):
        """
        Copy the temporary directory back to the proper path.
        Do not delete the temporary directory.

        The copy is skipped when the immediate caller is `_viz_callback`.
        """
        # context=0: only the caller's function name is needed, so skip
        # loading source-code context lines for every frame on the stack.
        if inspect.stack(context=0)[1].function != "_viz_callback":
            super()._copy_temporary_directory_contents_to_proper_path()

    @property
    def sampler_function_kwargs(self):
        """Keyword arguments forwarded to the integrator's `run` method.

        The accepted keys depend on which integrator is used. That is
        `NestedSampler` when `num_live_points` is set, and
        `ReactiveNestedSampler` otherwise. Keys absent from `self.kwargs`
        (e.g. a popped `dlogz`) are omitted.
        """
        if self.kwargs.get("num_live_points", None) is not None:
            keys = [
                "update_interval_iter",
                "update_interval_ncall",
                "log_interval",
                "dlogz",
                "max_iters",
            ]
        else:
            keys = [
                "update_interval_volume_fraction",
                "update_interval_ncall",
                "log_interval",
                "show_status",
                "viz_callback",
                "dlogz",
                "dKL",
                "frac_remain",
                "Lepsilon",
                "min_ess",
                "max_iters",
                "max_ncalls",
                "max_num_improvement_loops",
                "min_num_live_points",
                "cluster_num_live_points",
            ]

        return {key: self.kwargs[key] for key in keys if key in self.kwargs}

    @property
    def sampler_init_kwargs(self):
        """Keyword arguments forwarded to the integrator's constructor.

        `num_live_points` is included only for `NestedSampler`. The
        `num_test_samples`, `draw_multiple` and `num_bootstraps` keys are
        included only for `ReactiveNestedSampler`.
        """
        keys = [
            "derived_param_names",
            "resume",
            "run_num",
            "vectorized",
            "log_dir",
            "wrapped_params",
        ]
        if self.kwargs.get("num_live_points", None) is not None:
            keys += ["num_live_points"]
        else:
            keys += ["num_test_samples", "draw_multiple", "num_bootstraps"]

        return {key: self.kwargs[key] for key in keys if key in self.kwargs}

    @signal_wrapper
    def run_sampler(self):
        """Run ultranest and return the populated `self.result`.

        Uses `ultranest.integrator.NestedSampler` when `num_live_points` is
        set, otherwise `ultranest.integrator.ReactiveNestedSampler`.
        """
        import ultranest
        import ultranest.stepsampler

        if self.kwargs.get("dlogz") is None:
            # remove dlogz, so ultranest defaults (which are different for
            # NestedSampler and ReactiveNestedSampler) are used; tolerate the
            # key already being absent (e.g. on a repeated run)
            self.kwargs.pop("dlogz", None)

        self._verify_kwargs_against_default_kwargs()

        stepsampler = self.kwargs.pop("step_sampler", None)

        self._setup_run_directory()
        self.kwargs["log_dir"] = self.kwargs["outputfiles_basename"]
        self._check_and_load_sampling_time_file()

        # use reactive nested sampler when no live points are given
        if self.kwargs.get("num_live_points", None) is not None:
            integrator = ultranest.integrator.NestedSampler
        else:
            integrator = ultranest.integrator.ReactiveNestedSampler

        sampler = integrator(
            self.search_parameter_keys,
            self.log_likelihood,
            transform=self.prior_transform,
            **self.sampler_init_kwargs,
        )

        if stepsampler is not None:
            if isinstance(stepsampler, ultranest.stepsampler.StepSampler):
                sampler.stepsampler = stepsampler
            else:
                logger.warning(
                    "The supplied step sampler is not the correct type. "
                    "The default step sampling will be used instead."
                )

        if self.use_temporary_directory:
            self._viz_callback_counter = 1

        self.start_time = time.time()
        results = sampler.run(**self.sampler_function_kwargs)
        self._calculate_and_save_sampling_time()

        self._clean_up_run_directory()

        self._generate_result(results)
        self.calc_likelihood_count()

        return self.result

    def _generate_result(self, out):
        """Populate `self.result` from the ultranest output dictionary.

        Equal-weight posterior samples are obtained by rejection sampling the
        weighted samples. Each sample is kept with probability
        `weight / max(weight)`. The full weighted set is stored as
        `nested_samples`.
        """
        from ..utils.random import rng

        weighted = out["weighted_samples"]
        data = np.array(weighted["points"])
        weights = np.array(weighted["weights"])

        # rejection sampling: keep each sample with probability w / w_max
        scaledweights = weights / weights.max()
        mask = rng.uniform(0, 1, len(scaledweights)) < scaledweights

        nested_samples = DataFrame(data, columns=self.search_parameter_keys)
        nested_samples["weights"] = weights
        nested_samples["log_likelihood"] = weighted["logl"]
        self.result.log_likelihood_evaluations = np.array(weighted["logl"])[mask]
        self.result.sampler_output = out
        self.result.samples = data[mask, :]
        self.result.nested_samples = nested_samples
        self.result.log_evidence = out["logz"]
        self.result.log_evidence_err = out["logzerr"]
        num_live_points = self.kwargs.get("num_live_points", None)
        if num_live_points is not None:
            # information gain estimated as logzerr**2 * num_live_points
            self.result.information_gain = (
                np.power(out["logzerr"], 2) * num_live_points
            )

        self.result.outputfiles_basename = self.outputfiles_basename
        self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)

    def log_likelihood(self, theta):
        """Parent log-likelihood with non-finite values made finite.

        `np.nan_to_num` maps NaN to 0.0 and +/-inf to the largest finite
        floats. The assumption is that ultranest should never see
        non-finite values; TODO confirm.
        """
        log_l = super().log_likelihood(theta=theta)
        return np.nan_to_num(log_l)