Coverage for parallel_bilby/schwimmbad_fast.py: 92%
import atexit
import datetime
import json
import sys
import timeit
import traceback
from schwimmbad import _VERBOSE, MPIPool, log

# Module-level MPI handle, set by _import_mpi(); initialised to None so that
# MPIPoolFast.enabled() can test it before any pool has been constructed.
MPI = None

def _dummy_callback(x):
    pass

def _import_mpi(quiet=False, use_dill=False):
    global MPI
    import mpi4py

    # rc options must be set before the MPI submodule is imported:
    # run without thread support and do not use matched-probe receives.
    mpi4py.rc.threads = False
    mpi4py.rc.recv_mprobe = False
    from mpi4py import MPI as _MPI

    if use_dill:
        # dill can serialise objects (e.g. closures) that the standard pickle cannot
        import dill

        _MPI.pickle.__init__(dill.dumps, dill.loads, dill.HIGHEST_PROTOCOL)
    MPI = _MPI

    return MPI


class MPIPoolFast(MPIPool):
    """A processing pool with persistent MPI tasks.

    Schwimmbad's MPIPool starts the worker processes waiting in ``__init__``,
    but then finally calls ``sys.exit(0)``, so those processes never get a
    chance to do anything else.

    By default this class behaves like MPIPool, but its extra parameters
    allow the MPI worker tasks to persist beyond the pool.
    """
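
    # Keyword-argument summary (see __init__ below):
    #   comm             -- MPI communicator (defaults to MPI.COMM_WORLD)
    #   use_dill         -- use dill instead of pickle for MPI serialisation
    #   begin_wait       -- if True, worker ranks enter wait() inside __init__
    #   persistent_tasks -- if True, workers return from wait() instead of
    #                       calling sys.exit(0), so they can be reused
    #   parallel_comms   -- if True, the master does not block on each task
    #                       send in map()
    #   time_mpi         -- if True, collect per-category MPI timings (see
    #                       Timer); timing_interval > 0 additionally writes
    #                       periodic snapshots from the first worker rank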

    def __init__(
        self,
        comm=None,
        use_dill=False,
        begin_wait=True,
        persistent_tasks=True,
        parallel_comms=False,
        time_mpi=False,
        timing_interval=0,
    ):
        MPI = _import_mpi(use_dill=use_dill)

        if comm is None:
            comm = MPI.COMM_WORLD
        self.comm = comm

        self.master = 0
        self.rank = self.comm.Get_rank()

        self.pool_open = True
        atexit.register(lambda: MPIPool.close(self))

        # Option to enable parallel communication
        self.parallel_comms = parallel_comms

        # Initialise timer
        self.time_mpi = time_mpi
        if self.time_mpi:
            self.timer = Timer(self.rank, self.comm, self.master)
        else:
            self.timer = NullTimer()

        # Periodically save the timing output on the first worker task only
        # (timing_interval is specified in seconds)
        if self.rank == 1:
            self.timing_interval = timing_interval
        else:
            self.timing_interval = 0

        if self.timing_interval == 0:
            self.timing_interval = False

        if not self.is_master():
            if begin_wait:
                # workers branch here and wait for work
                try:
                    self.wait()
                except Exception:
                    print(f"worker with rank {self.rank} crashed".center(80, "="))
                    traceback.print_exc()
                    sys.stdout.flush()
                    sys.stderr.flush()
                    # shutdown all mpi tasks:
                    from mpi4py import MPI

                    MPI.COMM_WORLD.Abort()
                finally:
                    if not persistent_tasks:
                        sys.exit(0)

        else:
            self.workers = set(range(self.comm.size))
            self.workers.discard(self.master)
            self.size = self.comm.Get_size() - 1

            if self.size == 0:
                raise ValueError(
                    "Tried to create an MPI pool, but there "
                    "was only one MPI process available. "
                    "Need at least two."
                )

    @staticmethod
    def enabled():
        if MPI is None:
            _import_mpi(quiet=True)
        if MPI is not None:
            if MPI.COMM_WORLD.size > 1:
                return True
        return False

    def wait(self, callback=None):
        """Tell the workers to wait and listen for the master process. This is
        called automatically when using :meth:`MPIPool.map` and doesn't need to
        be called by the user.
        """
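        # Message protocol (implemented below and in map()/kill_workers()):
        #   "s" / "p"         -- timing flags from the master marking the start
        #                        of a serial / parallel section (only sent when
        #                        time_mpi is True)
        #   None              -- sentinel telling this worker to stop waiting
        #   (func, arg) tuple -- an actual task; the MPI tag carries the task id
        #                        and is echoed back with the result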
        if self.is_master():
            return

        worker = self.comm.rank
        status = MPI.Status()

        if self.timing_interval:
            time_snapshots = []
        self.timer.start("walltime")

        # Flag if master is performing a serial task (True by default)
        master_serial = True
        self.timer.start("master_serial")

        while True:
            # Receive task
            # (recv timer only gets counted if entering into a parallel task)
            self.timer.start("mpi_recv")
            if not master_serial:
                # start the barrier timer in case this is the last parallel task
                self.timer.start("barrier")

            log.log(_VERBOSE, f"Worker {worker} waiting for task")
            task = self.comm.recv(source=self.master, tag=MPI.ANY_TAG, status=status)

            # Indicator from master that a serial task is being performed
            if task == "s":
                self.timer.stop("barrier")  # count recv time towards barrier
                self.timer.start("master_serial")
                master_serial = True
            elif task == "p":
                self.timer.stop("master_serial")
                master_serial = False
            else:
                # Process task
                if task is None:
                    log.log(_VERBOSE, f"Worker {worker} told to quit work")

                    if master_serial:
                        self.timer.stop("master_serial")
                    else:
                        self.timer.stop("barrier")

                    break

                if master_serial and self.time_mpi:
                    print(
                        "Warning: Serial section has been flagged, but not unflagged yet. Timing will be inaccurate."
                    )
                self.timer.stop("mpi_recv")

                self.timer.start("compute")
                func, arg = task
                log.log(
                    _VERBOSE,
                    f"Worker {worker} got task {arg} with tag {status.tag}",
                )

                result = func(arg)
                self.timer.stop("compute")

                # Return results
                self.timer.start("mpi_send")
                log.log(
                    _VERBOSE,
                    f"Worker {worker} sending answer {result} with tag {status.tag}",
                )

                self.comm.ssend(result, self.master, status.tag)
                self.timer.stop("mpi_send")

            if self.timing_interval:
                self.timer.stop("walltime")
                if self.timer.interval_time["walltime"] > self.timing_interval:
                    time_snapshots += [self.timer.interval_time.copy()]
                    self.timer.reset()
                self.timer.start("walltime")

        if self.timing_interval:
            with open("mpi_worker_timing.json", "w") as f:
                json.dump(time_snapshots, f)

        if callback is not None:
            callback()

    def map(self, worker, tasks, callback=None):
        """Evaluate a function or callable on each task in parallel using MPI.

        The callable, ``worker``, is called on each element of the ``tasks``
        iterable. The results are returned in the expected order (symmetric with
        ``tasks``).

        Parameters
        ----------
        worker : callable
            A function or callable object that is executed on each element of
            the specified ``tasks`` iterable. This object must be picklable
            (i.e. it can't be a function scoped within a function or a
            ``lambda`` function). It should accept a single positional
            argument and return a single object.
        tasks : iterable
            A list or iterable of tasks. Each task can itself be an iterable
            (e.g., a tuple) of values or data to pass in to the worker function.
        callback : callable, optional
            An optional callback function (or callable) that is called with the
            result from each worker run and is executed on the master process.
            This is useful for, e.g., saving results to a file, since the
            callback is only called on the master process.

        Returns
        -------
        results : list
            A list of results from the output of each ``worker()`` call.
        """
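        # Dispatch strategy: each task is sent with a non-blocking isend whose
        # tag is the task id. While tasks remain, the master polls for finished
        # results with Iprobe so it can keep feeding idle workers; once the
        # task list is empty it falls back to a blocking Probe. With
        # parallel_comms=False every send request is waited on immediately,
        # which makes the sends effectively blocking.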
        # If not the master just wait for instructions.
        if not self.is_master():
            self.wait()
            return

        if callback is None:
            callback = _dummy_callback

        workerset = self.workers.copy()
        tasklist = [(tid, (worker, arg)) for tid, arg in enumerate(tasks)]
        resultlist = [None] * len(tasklist)
        pending = len(tasklist)

        # Buffers for each worker (worker index starts from 1)
        reqlist = [None] * len(workerset)
        taskbuffer = [None] * len(workerset)

        self.flag_parallel()

        while pending:
            if workerset and tasklist:
                worker = workerset.pop()
                ibuf = worker - 1
                taskid, taskbuffer[ibuf] = tasklist.pop()
                log.log(
                    _VERBOSE,
                    "Sent task %s to worker %s with tag %s",
                    taskbuffer[ibuf][1],
                    worker,
                    taskid,
                )
                # Create send request - no need to test because result return is a sufficient indicator
                reqlist[ibuf] = self.comm.isend(
                    taskbuffer[ibuf], dest=worker, tag=taskid
                )
                if not self.parallel_comms:
                    reqlist[ibuf].wait()

            if tasklist:
                flag = self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG)
                if not flag:
                    continue
            else:
                self.comm.Probe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG)

            status = MPI.Status()
            result = self.comm.recv(
                source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status
            )
            worker = status.source
            taskid = status.tag
            log.log(
                _VERBOSE, "Master received from worker %s with tag %s", worker, taskid
            )

            callback(result)

            workerset.add(worker)
            resultlist[taskid] = result
            pending -= 1

        self.flag_serial()

        return resultlist

    def close(self):
        """When the master task is done, tidy up."""
        # Only kill workers if the pool is open, otherwise a leftover
        # MPI message will remain and kill the next pool that opens
        if not self.pool_open:
            raise RuntimeError(
                "Attempting to close schwimmbad pool that has already been closed"
            )

        self.pool_open = False
        if self.is_master():
            self.kill_workers()

        if self.time_mpi:
            self.timer.parallel_total()

    def kill_workers(self):
        """Tell all the workers to quit."""
        buf = None
        for worker in self.workers:
            self.comm.send(buf, dest=worker, tag=0)

    def flag_serial(self):
        """Tell all the workers that serial code is being executed."""
        if self.time_mpi:
            buf = "s"
            for worker in self.workers:
                self.comm.send(buf, dest=worker, tag=0)

    def flag_parallel(self):
        """Tell all the workers that serial code has finished."""
        if self.time_mpi:
            buf = "p"
            for worker in self.workers:
                self.comm.send(buf, dest=worker, tag=0)


class Timer:
    """Accumulate per-category wall-clock timings on each MPI rank."""

    def __init__(self, rank, comm, master):
        self.rank = rank
        self.comm = comm
        self.master = master

        self.cumulative_time = {}
        self.interval_time = {}
        self.start_time = {}
        self.total = {}

        self.group = [
            "master_serial",
            "mpi_recv",
            "compute",
            "mpi_send",
            "barrier",
            "walltime",
        ]

        self.reset_all()
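
    # Timing categories (started/stopped in MPIPoolFast.wait()):
    #   master_serial -- worker idle while the master runs serial code
    #   mpi_recv      -- waiting to receive a task during a parallel section
    #   compute       -- executing the task function
    #   mpi_send      -- sending the result back to the master
    #   barrier       -- idle after the last task of a parallel section,
    #                    waiting for the master to flag serial work again
    #   walltime      -- total elapsed time, used for the periodic snapshots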

    def start(self, name):
        self.start_time[name] = timeit.time.perf_counter()

    def stop(self, name):
        now = timeit.time.perf_counter()
        dt = now - self.start_time[name]
        self.interval_time[name] += dt
        self.cumulative_time[name] += dt

    def reset_all(self):
        for name in self.group:
            self.start_time[name] = 0
            self.cumulative_time[name] = 0
        self.reset()

    def reset(self):
        for name in self.group:
            self.interval_time[name] = 0

    def parallel_total(self):
        if self.rank == self.master:
            for name in self.group:
                self.total[name] = 0

            status = MPI.Status()
            for isrc in range(1, self.comm.Get_size()):
                times = self.comm.recv(source=isrc, tag=1, status=status)
                for name in self.group:
                    self.total[name] += times[name]

            print("MPI Timer -- cumulative wall time of each task")
            total_time = 0
            for name in self.group:
                if name == "walltime":
                    continue
                total_time += self.total[name]
            for name in self.group:
                str_time = str(datetime.timedelta(seconds=self.total[name]))
                str_percent = f"{100 * self.total[name] / total_time:.2f}%"
                print(f" {name: <16}: {str_time: <10} ({str_percent: <5})")
            print(
                f" Total time: {str(datetime.timedelta(seconds=total_time))}"
                f" ({total_time:.2f} s)"
            )

        else:
            self.comm.send(self.cumulative_time, dest=self.master, tag=1)


class NullTimer(Timer):
    """A no-op Timer used when MPI timing is disabled."""

    def __init__(self):
        return

    def start(self, name):
        return

    def stop(self, name):
        return

    def reset(self):
        return

    def reset_all(self):
        return

    def parallel_total(self):
        return

    def __str__(self):
        return ""
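

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the pool implementation
# above): how the pool might be driven when the script is launched under MPI,
# e.g. ``mpiexec -n 4 python this_script.py``. The worker function ``square``
# and the task range are hypothetical placeholders; the ``with`` form relies
# on the context-manager support inherited from schwimmbad's pools.
def square(x):
    return x * x


if __name__ == "__main__":
    # Worker ranks block inside MPIPoolFast.__init__ until the master closes
    # the pool, so only the master reaches the body of the ``with`` block
    # while the pool is open.
    with MPIPoolFast() as pool:
        if pool.is_master():
            results = pool.map(square, range(10))
            print(results)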