Coverage for parallel_bilby/schwimmbad_fast.py: 92%

238 statements  

import atexit
import datetime
import json
import sys
import timeit
import traceback

from schwimmbad import _VERBOSE, MPIPool, log

# Module-level MPI handle, set by _import_mpi(); initialised to None so that
# enabled() can safely test whether MPI has been imported yet.
MPI = None


def _dummy_callback(x):
    pass


def _import_mpi(quiet=False, use_dill=False):
    global MPI
    import mpi4py

    # mpi4py runtime options must be set before MPI is first imported
    mpi4py.rc.threads = False
    mpi4py.rc.recv_mprobe = False
    from mpi4py import MPI as _MPI

    if use_dill:
        import dill

        _MPI.pickle.__init__(dill.dumps, dill.loads, dill.HIGHEST_PROTOCOL)
    MPI = _MPI

    return MPI
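For context on the use_dill option above, here is a minimal standalone sketch (not part of schwimmbad_fast.py; it assumes the dill package is installed): the standard pickle serializer cannot handle closures or lambdas, whereas dill can, which is what the replacement pickler enables for task payloads.

import dill


def make_task(offset):
    # A closure: the plain pickle module cannot serialize this, but dill can
    return lambda x: x + offset


if __name__ == "__main__":
    task = dill.loads(dill.dumps(make_task(10)))
    print(task(5))  # prints 15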

class MPIPoolFast(MPIPool):
    """A processing pool with persistent MPI tasks.

    Schwimmbad's MPIPool starts the worker tasks waiting in __init__, but
    then calls sys.exit(0), so those workers never get a chance to do
    anything else.

    By default this class behaves like MPIPool, but its parameters allow
    the MPI worker tasks to persist beyond the pool. See the usage sketch
    after this class definition.
    """

    def __init__(
        self,
        comm=None,
        use_dill=False,
        begin_wait=True,
        persistent_tasks=True,
        parallel_comms=False,
        time_mpi=False,
        timing_interval=0,
    ):
        MPI = _import_mpi(use_dill=use_dill)

        if comm is None:
            comm = MPI.COMM_WORLD
        self.comm = comm

        self.master = 0
        self.rank = self.comm.Get_rank()

        self.pool_open = True
        atexit.register(lambda: MPIPool.close(self))

        # Option to enable parallel communication
        self.parallel_comms = parallel_comms

        # Initialise timer
        self.time_mpi = time_mpi
        if self.time_mpi:
            self.timer = Timer(self.rank, self.comm, self.master)
        else:
            self.timer = NullTimer()

        # Periodically save the timing output, on the first worker task only
        # (interval specified in seconds)
        if self.rank == 1:
            self.timing_interval = timing_interval
        else:
            self.timing_interval = 0

        if self.timing_interval == 0:
            self.timing_interval = False

        if not self.is_master():
            if begin_wait:
                # Workers branch here and wait for work
                try:
                    self.wait()
                except Exception:
                    print(f"worker with rank {self.rank} crashed".center(80, "="))
                    traceback.print_exc()
                    sys.stdout.flush()
                    sys.stderr.flush()
                    # Shut down all MPI tasks
                    from mpi4py import MPI

                    MPI.COMM_WORLD.Abort()
                finally:
                    if not persistent_tasks:
                        sys.exit(0)

        else:
            self.workers = set(range(self.comm.size))
            self.workers.discard(self.master)
            self.size = self.comm.Get_size() - 1

            if self.size == 0:
                raise ValueError(
                    "Tried to create an MPI pool, but there "
                    "was only one MPI process available. "
                    "Need at least two."
                )

    @staticmethod
    def enabled():
        if MPI is None:
            _import_mpi(quiet=True)
        if MPI is not None:
            if MPI.COMM_WORLD.size > 1:
                return True
        return False

    def wait(self, callback=None):
        """Tell the workers to wait and listen for the master process. This is
        called automatically when using :meth:`MPIPool.map` and doesn't need to
        be called by the user.
        """
        if self.is_master():
            return

        worker = self.comm.rank
        status = MPI.Status()

        if self.timing_interval:
            time_snapshots = []
            self.timer.start("walltime")

        # Flag if master is performing a serial task (True by default)
        master_serial = True
        self.timer.start("master_serial")

        while True:
            # Receive task. The recv timer is only counted if this turns out
            # to be a parallel task.
            self.timer.start("mpi_recv")
            if not master_serial:
                # Start the barrier timer in case this is the last parallel task
                self.timer.start("barrier")

            log.log(_VERBOSE, f"Worker {worker} waiting for task")
            task = self.comm.recv(source=self.master, tag=MPI.ANY_TAG, status=status)

            # Indicator from master that a serial task is being performed
            if task == "s":
                self.timer.stop("barrier")  # count recv time towards barrier
                self.timer.start("master_serial")
                master_serial = True
            elif task == "p":
                self.timer.stop("master_serial")
                master_serial = False
            else:
                # Process task
                if task is None:
                    log.log(_VERBOSE, f"Worker {worker} told to quit work")

                    if master_serial:
                        self.timer.stop("master_serial")
                    else:
                        self.timer.stop("barrier")

                    break

                if master_serial and self.time_mpi:
                    print(
                        "Warning: Serial section has been flagged, but not "
                        "unflagged yet. Timing will be inaccurate."
                    )
                self.timer.stop("mpi_recv")

                self.timer.start("compute")
                func, arg = task
                log.log(
                    _VERBOSE,
                    f"Worker {worker} got task {arg} with tag {status.tag}",
                )

                result = func(arg)
                self.timer.stop("compute")

                # Return results
                self.timer.start("mpi_send")
                log.log(
                    _VERBOSE,
                    f"Worker {worker} sending answer {result} with tag {status.tag}",
                )

                self.comm.ssend(result, self.master, status.tag)
                self.timer.stop("mpi_send")

            if self.timing_interval:
                self.timer.stop("walltime")
                if self.timer.interval_time["walltime"] > self.timing_interval:
                    time_snapshots += [self.timer.interval_time.copy()]
                    self.timer.reset()
                self.timer.start("walltime")

        if self.timing_interval:
            with open("mpi_worker_timing.json", "w") as f:
                json.dump(time_snapshots, f)

        if callback is not None:
            callback()
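    # For reference, the master/worker message protocol handled by the wait()
    # loop above (summarized from the code; the task payloads come from map()):
    #
    #   "s"          -> master has entered a serial section (timing flag only)
    #   "p"          -> master has entered a parallel section (timing flag only)
    #   None         -> leave the wait loop (sent by kill_workers)
    #   (func, arg)  -> compute result = func(arg) and ssend it back using the
    #                   same tag, which map() uses as the task index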

    def map(self, worker, tasks, callback=None):
        """Evaluate a function or callable on each task in parallel using MPI.

        The callable, ``worker``, is called on each element of the ``tasks``
        iterable. The results are returned in the expected order (symmetric with
        ``tasks``).

        Parameters
        ----------
        worker : callable
            A function or callable object that is executed on each element of
            the specified ``tasks`` iterable. This object must be picklable
            (i.e. it can't be a function scoped within a function or a
            ``lambda`` function). This should accept a single positional
            argument and return a single object.
        tasks : iterable
            A list or iterable of tasks. Each task can itself be an iterable
            (e.g., tuple) of values or data to pass in to the worker function.
        callback : callable, optional
            An optional callback function (or callable) that is called with the
            result from each worker run and is executed on the master process.
            This is useful for, e.g., saving results to a file, since the
            callback is only called on the master process.

        Returns
        -------
        results : list
            A list of results from the output of each ``worker()`` call.
        """

        # If not the master just wait for instructions.
        if not self.is_master():
            self.wait()
            return

        if callback is None:
            callback = _dummy_callback

        workerset = self.workers.copy()
        tasklist = [(tid, (worker, arg)) for tid, arg in enumerate(tasks)]
        resultlist = [None] * len(tasklist)
        pending = len(tasklist)

        # Buffers for each worker (worker index starts from 1)
        reqlist = [None] * len(workerset)
        taskbuffer = [None] * len(workerset)

        self.flag_parallel()

        while pending:
            if workerset and tasklist:
                # NB: from here on `worker` holds a worker rank; the task
                # callable was already captured in tasklist above.
                worker = workerset.pop()
                ibuf = worker - 1
                taskid, taskbuffer[ibuf] = tasklist.pop()
                log.log(
                    _VERBOSE,
                    "Sent task %s to worker %s with tag %s",
                    taskbuffer[ibuf][1],
                    worker,
                    taskid,
                )
                # Create send request - no need to test it, because the
                # returned result is a sufficient indicator
                reqlist[ibuf] = self.comm.isend(
                    taskbuffer[ibuf], dest=worker, tag=taskid
                )
                if not self.parallel_comms:
                    reqlist[ibuf].wait()

            if tasklist:
                flag = self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG)
                if not flag:
                    continue
            else:
                self.comm.Probe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG)

            status = MPI.Status()
            result = self.comm.recv(
                source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status
            )
            worker = status.source
            taskid = status.tag
            log.log(
                _VERBOSE, "Master received from worker %s with tag %s", worker, taskid
            )

            callback(result)

            workerset.add(worker)
            resultlist[taskid] = result
            pending -= 1

        self.flag_serial()

        return resultlist

    def close(self):
        """When the master task is done, tidy up."""
        # Only kill workers if the pool is open, otherwise a leftover
        # MPI message will remain and kill the next pool that opens
        if not self.pool_open:
            raise RuntimeError(
                "Attempting to close schwimmbad pool that has already been closed"
            )

        self.pool_open = False
        if self.is_master():
            self.kill_workers()

        if self.time_mpi:
            self.timer.parallel_total()

    def kill_workers(self):
        """Tell all the workers to quit."""
        buf = None
        for worker in self.workers:
            self.comm.send(buf, dest=worker, tag=0)

    def flag_serial(self):
        """Tell all the workers that serial code is being executed."""
        if self.time_mpi:
            buf = "s"
            for worker in self.workers:
                self.comm.send(buf, dest=worker, tag=0)

    def flag_parallel(self):
        """Tell all the workers that serial code has finished."""
        if self.time_mpi:
            buf = "p"
            for worker in self.workers:
                self.comm.send(buf, dest=worker, tag=0)
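As a usage illustration for the pool and its map method, here is a minimal sketch (not part of this module; the task function and launch command are placeholders):

from parallel_bilby.schwimmbad_fast import MPIPoolFast


def square(x):
    # Task callables must be picklable: defined at module level, not a lambda
    return x * x


if __name__ == "__main__":
    # Non-master ranks enter wait() inside __init__ (begin_wait=True by default)
    # and, because persistent_tasks=True, return here only after the master
    # closes the pool instead of calling sys.exit(0).
    pool = MPIPoolFast()

    if pool.is_master():
        # An optional callback=... would be called on the master for each result
        results = pool.map(square, range(16))  # ordered like the input tasks
        print(results)
        pool.close()

    # Launch with at least two MPI ranks, e.g.:
    #   mpirun -n 4 python example_pool.py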

class Timer:
    """Accumulate per-interval and cumulative wall time for a fixed set of
    named timing categories on each MPI task."""

    def __init__(self, rank, comm, master):
        self.rank = rank
        self.comm = comm
        self.master = master

        self.cumulative_time = {}
        self.interval_time = {}
        self.start_time = {}
        self.total = {}

        self.group = [
            "master_serial",
            "mpi_recv",
            "compute",
            "mpi_send",
            "barrier",
            "walltime",
        ]

        self.reset_all()

    def start(self, name):
        self.start_time[name] = timeit.time.perf_counter()

    def stop(self, name):
        now = timeit.time.perf_counter()
        dt = now - self.start_time[name]
        self.interval_time[name] += dt
        self.cumulative_time[name] += dt

    def reset_all(self):
        for name in self.group:
            self.start_time[name] = 0
            self.cumulative_time[name] = 0
        self.reset()

    def reset(self):
        for name in self.group:
            self.interval_time[name] = 0

    def parallel_total(self):
        if self.rank == self.master:
            for name in self.group:
                self.total[name] = 0

            status = MPI.Status()
            for isrc in range(1, self.comm.Get_size()):
                times = self.comm.recv(source=isrc, tag=1, status=status)
                for name in self.group:
                    self.total[name] += times[name]

            print("MPI Timer -- cumulative wall time of each task")
            all = 0
            for name in self.group:
                if name == "walltime":
                    continue
                all += self.total[name]
            for name in self.group:
                str_time = str(datetime.timedelta(seconds=self.total[name]))
                str_percent = f"{100 * self.total[name] / all:.2f}%"
                print(f" {name: <16}: {str_time: <10} ({str_percent: <5})")
            print(f" Total time: {str(datetime.timedelta(seconds=all))} ({all:.2f} s)")

        else:
            self.comm.send(self.cumulative_time, dest=self.master, tag=1)
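The Timer bookkeeping can be exercised on its own; a standalone sketch follows (parallel_total is not called here, since it needs a real MPI communicator):

import time

from parallel_bilby.schwimmbad_fast import Timer

# comm is only used by parallel_total(), so None is fine for a serial demo
t = Timer(rank=0, comm=None, master=0)

t.start("compute")
time.sleep(0.1)
t.stop("compute")
print(t.interval_time["compute"], t.cumulative_time["compute"])

t.reset()  # clears interval_time but leaves cumulative_time accumulating
print(t.interval_time["compute"], t.cumulative_time["compute"])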

class NullTimer(Timer):
    """No-op stand-in for Timer, used when MPI timing (time_mpi) is disabled."""

    def __init__(self):
        return

    def start(self, name):
        return

    def stop(self, name):
        return

    def reset(self):
        return

    def reset_all(self):
        return

    def parallel_total(self):
        return

    def __str__(self):
        return ""