Coverage for parallel_bilby/analysis/analysis_run.py: 81%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

149 statements  

1import logging 

2import os 

3import pickle 

4 

5import bilby 

6import dynesty 

7import numpy as np 

8from bilby.core.sampler.base_sampler import _SamplingContainer 

9from bilby.core.sampler.dynesty import DynestySetupError, _set_sampling_kwargs 

10from bilby.core.sampler.dynesty_utils import ( 

11 AcceptanceTrackingRWalk, 

12 ACTTrackingRWalk, 

13 FixedRWalk, 

14 LivePointSampler, 

15 MultiEllipsoidLivePointSampler, 

16) 

17from bilby.core.utils import logger 

18from bilby_pipe.utils import convert_string_to_list 

19 

20from .likelihood import setup_likelihood 

21 

22 

class AnalysisRun(object):
    """
    An object with methods for driving the sampling run.

    Parameters: arguments to set the output path and control the dynesty sampler.
    """

    def __init__(
        self,
        data_dump,
        outdir=None,
        label=None,
        dynesty_sample="acceptance-walk",
        nlive=5,
        dynesty_bound="live",
        walks=100,
        maxmcmc=5000,
        naccept=60,
        nact=2,
        facc=0.5,
        min_eff=10,
        enlarge=1.5,
        sampling_seed=0,
        proposals=None,
        bilby_zero_likelihood_mode=False,
    ):
        self.maxmcmc = maxmcmc
        self.nact = nact
        self.naccept = naccept
        self.proposals = convert_string_to_list(proposals)

        # Load the data dump produced by the generation stage.
        with open(data_dump, "rb") as f:
            dump = pickle.load(f)

        ifo_list = dump["ifo_list"]
        waveform_generator = dump["waveform_generator"]
        waveform_generator.start_time = ifo_list[0].time_array[0]
        args = dump["args"]
        injection_parameters = dump.get("injection_parameters", None)

        args.weight_file = dump["meta_data"].get("weight_file", None)

        # If the run dir has not been specified, get it from the args
        if outdir is None:
            outdir = args.outdir
        else:
            # Create the run dir
            os.makedirs(outdir, exist_ok=True)

        # If the label has not been specified, get it from the args
        if label is None:
            label = args.label

        priors = bilby.gw.prior.PriorDict.from_json(dump["prior_file"])

        # Temporarily raise the log level so likelihood setup stays quiet.
        logger.setLevel(logging.WARNING)
        likelihood = setup_likelihood(
            interferometers=ifo_list,
            waveform_generator=waveform_generator,
            priors=priors,
            args=args,
        )
        priors.convert_floats_to_delta_functions()
        logger.setLevel(logging.INFO)

        # Split the priors into sampled parameters and fixed parameters;
        # fixed values are written straight into the likelihood.
        sampling_keys = []
        for name, prior in priors.items():
            if isinstance(prior, bilby.core.prior.Constraint):
                continue
            if prior.is_fixed:
                likelihood.parameters[name] = prior.peak
            else:
                sampling_keys.append(name)

        # Record the indices of parameters with periodic/reflective boundaries.
        boundary_indices = {"periodic": [], "reflective": []}
        for idx, key in enumerate(sampling_keys):
            boundary = priors[key].boundary
            if boundary in boundary_indices:
                logger.debug(f"Setting {boundary} boundary for {key}")
                boundary_indices[boundary].append(idx)

        # dynesty expects None (not an empty list) when no boundaries are set.
        periodic = boundary_indices["periodic"] or None
        reflective = boundary_indices["reflective"] or None

        self.init_sampler_kwargs = {
            "nlive": nlive,
            "sample": dynesty_sample,
            "bound": dynesty_bound,
            "walks": walks,
            "facc": facc,
            "first_update": {"min_eff": min_eff, "min_ncall": 2 * nlive},
            "enlarge": enlarge,
        }

        self._set_sampling_method()

        # Create a random generator, which is saved across restarts
        # This ensures that runs are fully deterministic, which is important
        # for reproducibility
        self.sampling_seed = sampling_seed
        self.rstate = np.random.Generator(np.random.PCG64(self.sampling_seed))
        logger.debug(
            f"Setting random state = {self.rstate} (seed={self.sampling_seed})"
        )

        self.outdir = outdir
        self.label = label
        self.data_dump = dump
        self.priors = priors
        self.sampling_keys = sampling_keys
        self.likelihood = likelihood
        self.zero_likelihood_mode = bilby_zero_likelihood_mode
        self.periodic = periodic
        self.reflective = reflective
        self.args = args
        self.injection_parameters = injection_parameters
        self.nlive = nlive

147 def _set_sampling_method(self): 

148 

149 sample = self.init_sampler_kwargs["sample"] 

150 bound = self.init_sampler_kwargs["bound"] 

151 

152 _set_sampling_kwargs((self.nact, self.maxmcmc, self.proposals, self.naccept)) 

153 

154 if sample not in ["rwalk", "act-walk", "acceptance-walk"] and bound in [ 

155 "live", 

156 "live-multi", 

157 ]: 

158 logger.info( 

159 "Live-point based bound method requested with dynesty sample " 

160 f"'{sample}', overwriting to 'multi'" 

161 ) 

162 self.init_sampler_kwargs["bound"] = "multi" 

163 elif bound == "live": 

164 dynesty.dynamicsampler._SAMPLERS["live"] = LivePointSampler 

165 elif bound == "live-multi": 

166 dynesty.dynamicsampler._SAMPLERS[ 

167 "live-multi" 

168 ] = MultiEllipsoidLivePointSampler 

169 elif sample == "acceptance-walk": 

170 raise DynestySetupError( 

171 "bound must be set to live or live-multi for sample=acceptance-walk" 

172 ) 

173 elif self.proposals is None: 

174 logger.warning( 

175 "No proposals specified using dynesty sampling, defaulting " 

176 "to 'volumetric'." 

177 ) 

178 self.proposals = ["volumetric"] 

179 _SamplingContainer.proposals = self.proposals 

180 elif "diff" in self.proposals: 

181 raise DynestySetupError( 

182 "bound must be set to live or live-multi to use differential " 

183 "evolution proposals" 

184 ) 

185 

186 if sample == "rwalk": 

187 logger.info( 

188 "Using the bilby-implemented rwalk sample method with ACT estimated walks. " 

189 f"An average of {2 * self.nact} steps will be accepted up to chain length " 

190 f"{self.maxmcmc}." 

191 ) 

192 if self.init_sampler_kwargs["walks"] > self.maxmcmc: 

193 raise DynestySetupError("You have maxmcmc < walks (minimum mcmc)") 

194 if self.nact < 1: 

195 raise DynestySetupError("Unable to run with nact < 1") 

196 dynesty.nestedsamplers._SAMPLING["rwalk"] = AcceptanceTrackingRWalk() 

197 elif sample == "acceptance-walk": 

198 logger.info( 

199 "Using the bilby-implemented rwalk sampling with an average of " 

200 f"{self.naccept} accepted steps per MCMC and maximum length {self.maxmcmc}" 

201 ) 

202 dynesty.nestedsamplers._SAMPLING["acceptance-walk"] = FixedRWalk() 

203 elif sample == "act-walk": 

204 logger.info( 

205 "Using the bilby-implemented rwalk sampling tracking the " 

206 f"autocorrelation function and thinning by " 

207 f"{self.nact} with maximum length {self.nact * self.maxmcmc}" 

208 ) 

209 dynesty.nestedsamplers._SAMPLING["act-walk"] = ACTTrackingRWalk() 

210 elif sample == "rwalk_dynesty": 

211 sample = sample.strip("_dynesty") 

212 self.init_sampler_kwargs["sample"] = sample 

213 logger.info(f"Using the dynesty-implemented {sample} sample method") 

214 

215 def prior_transform_function(self, u_array): 

216 """ 

217 Calls the bilby rescaling function on an array of values 

218 

219 Parameters 

220 ---------- 

221 u_array: (float, array-like) 

222 The values to rescale 

223 

224 Returns 

225 ------- 

226 (float, array-like) 

227 The rescaled values 

228 

229 """ 

230 return self.priors.rescale(self.sampling_keys, u_array) 

231 

232 def log_likelihood_function(self, v_array): 

233 """ 

234 Calculates the log(likelihood) 

235 

236 Parameters 

237 ---------- 

238 u_array: (float, array-like) 

239 The values to rescale 

240 

241 Returns 

242 ------- 

243 (float, array-like) 

244 The rescaled values 

245 

246 """ 

247 if self.zero_likelihood_mode: 

248 return 0 

249 parameters = {key: v for key, v in zip(self.sampling_keys, v_array)} 

250 if self.priors.evaluate_constraints(parameters) > 0: 

251 self.likelihood.parameters.update(parameters) 

252 return ( 

253 self.likelihood.log_likelihood() 

254 - self.likelihood.noise_log_likelihood() 

255 ) 

256 else: 

257 return np.nan_to_num(-np.inf) 

258 

259 def log_prior_function(self, v_array): 

260 """ 

261 Calculates the log of the prior 

262 

263 Parameters 

264 ---------- 

265 v_array: (float, array-like) 

266 The prior values 

267 

268 Returns 

269 ------- 

270 (float, array-like) 

271 The log probability of the values 

272 

273 """ 

274 params = {key: t for key, t in zip(self.sampling_keys, v_array)} 

275 return self.priors.ln_prob(params) 

276 

277 def get_initial_points_from_prior(self, pool, calculate_likelihood=True): 

278 """ 

279 Generates a set of initial points, drawn from the prior 

280 

281 Parameters 

282 ---------- 

283 pool: schwimmbad.MPIPool 

284 Schwimmbad pool for MPI parallelisation 

285 (pbilby implements a modified version: MPIPoolFast) 

286 

287 calculate_likelihood: bool 

288 Option to calculate the likelihood for the generated points 

289 (default: True) 

290 

291 Returns 

292 ------- 

293 (numpy.ndarraym, numpy.ndarray, numpy.ndarray, None) 

294 Returns a tuple (unit, theta, logl, blob) 

295 unit: point in the unit cube 

296 theta: scaled value 

297 logl: log(likelihood) 

298 blob: None 

299 

300 """ 

301 # Create a new rstate for each point, otherwise each task will generate 

302 # the same random number, and the rstate on master will not be incremented. 

303 # The argument to self.rstate.integers() is a very large integer. 

304 # These rstates aren't used after this map, but each time they are created, 

305 # a different (but deterministic) seed is used. 

306 sg = np.random.SeedSequence(self.rstate.integers(9223372036854775807)) 

307 map_rstates = [ 

308 np.random.Generator(np.random.PCG64(n)) for n in sg.spawn(self.nlive) 

309 ] 

310 ndim = len(self.sampling_keys) 

311 

312 args_list = [ 

313 ( 

314 self.prior_transform_function, 

315 self.log_prior_function, 

316 self.log_likelihood_function, 

317 ndim, 

318 calculate_likelihood, 

319 map_rstates[i], 

320 ) 

321 for i in range(self.nlive) 

322 ] 

323 initial_points = pool.map(self.get_initial_point_from_prior, args_list) 

324 u_list = [point[0] for point in initial_points] 

325 v_list = [point[1] for point in initial_points] 

326 l_list = [point[2] for point in initial_points] 

327 blobs = None 

328 

329 return np.array(u_list), np.array(v_list), np.array(l_list), blobs 

330 

331 @staticmethod 

332 def get_initial_point_from_prior(args): 

333 """ 

334 Draw initial points from the prior subject to constraints applied both to 

335 the prior and the likelihood. 

336 

337 We remove any points where the likelihood or prior is infinite or NaN. 

338 

339 The `log_likelihood_function` often converts infinite values to large 

340 finite values so we catch those. 

341 """ 

342 ( 

343 prior_transform_function, 

344 log_prior_function, 

345 log_likelihood_function, 

346 ndim, 

347 calculate_likelihood, 

348 rstate, 

349 ) = args 

350 bad_values = [np.inf, np.nan_to_num(np.inf), np.nan] 

351 while True: 

352 unit = rstate.random(ndim) 

353 theta = prior_transform_function(unit) 

354 

355 if abs(log_prior_function(theta)) not in bad_values: 

356 if calculate_likelihood: 

357 logl = log_likelihood_function(theta) 

358 if abs(logl) not in bad_values: 

359 return unit, theta, logl 

360 else: 

361 return unit, theta, np.nan 

362 

363 def get_nested_sampler(self, live_points, pool, pool_size): 

364 """ 

365 Returns the dynested nested sampler, getting most arguments 

366 from the object's attributes 

367 

368 Parameters 

369 ---------- 

370 live_points: (numpy.ndarraym, numpy.ndarray, numpy.ndarray) 

371 The set of live points, in the same format as returned by 

372 get_initial_points_from_prior 

373 

374 pool: schwimmbad.MPIPool 

375 Schwimmbad pool for MPI parallelisation 

376 (pbilby implements a modified version: MPIPoolFast) 

377 

378 pool_size: int 

379 Number of workers in the pool 

380 

381 Returns 

382 ------- 

383 dynesty.NestedSampler 

384 

385 """ 

386 ndim = len(self.sampling_keys) 

387 sampler = dynesty.NestedSampler( 

388 self.log_likelihood_function, 

389 self.prior_transform_function, 

390 ndim, 

391 pool=pool, 

392 queue_size=pool_size, 

393 periodic=self.periodic, 

394 reflective=self.reflective, 

395 live_points=live_points, 

396 rstate=self.rstate, 

397 use_pool=dict( 

398 update_bound=True, 

399 propose_point=True, 

400 prior_transform=True, 

401 loglikelihood=True, 

402 ), 

403 **self.init_sampler_kwargs, 

404 ) 

405 

406 return sampler