#!/usr/bin/env python
# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================
"""
MPI
===
Utilities to run on MPI.
Provides functions and decorators that simplify running the same code on
multiple nodes. One benefit is that the serial and parallel code paths
are exactly the same.
Global variables
----------------
disable_mpi : bool
Set this to True to force running serially.
Routines
--------
:func:`get_mpicomm`
Automatically detect and configure MPI execution and return an
MPI communicator.
:func:`run_single_node`
Run a task on a single node.
:func:`on_single_node`
Decorator version of :func:`run_single_node`.
:func:`distribute`
Map a task over a sequence of arguments across all the nodes.
:func:`delay_termination`
A context manager to delay the response to termination signals.
:func:`delayed_termination`
A decorator version of :func:`delay_termination`.
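
Examples
--------
The same code runs serially or under ``mpirun`` without modification.
A minimal sketch (``square`` here is a stand-in task):

>>> def square(x):
...     return x**2
>>> distribute(square, [1, 2, 3], send_results_to='all')
[1, 4, 9]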
"""
# ==============================================================================
# GLOBAL IMPORTS
# ==============================================================================
import functools
import logging
import os
import sys
import signal
from contextlib import contextmanager
from traceback import format_exception
import numpy as np
logger = logging.getLogger(__name__)
# ==============================================================================
# GLOBAL VARIABLES
# ==============================================================================
# Force serial execution even in MPI environment.
disable_mpi = False
# A dummy MPI communicator used to simulate an MPI environment in tests.
_simulated_mpicomm = None
# ==============================================================================
# MAIN FUNCTIONS
# ==============================================================================
def get_mpicomm():
"""Retrieve the MPI communicator for this execution.
The function automatically detects if the program runs on MPI by checking
specific environment variables set by various MPI implementations. On
first execution, it modifies ``sys.excepthook`` and registers a handler for
SIGINT, SIGTERM, and SIGABRT that calls MPI's ``Abort()`` to correctly terminate all
processes.
Returns
-------
mpicomm : mpi4py communicator or None
The communicator for this node, None if the program doesn't run
with MPI.
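
Examples
--------
A common guard for rank-dependent code (a sketch; when the program is
not launched with ``mpirun``, the communicator is None and the guarded
block simply runs locally):

>>> mpicomm = get_mpicomm()
>>> if mpicomm is None or mpicomm.rank == 0:
...     pass  # Work executed only by the root node (or serially).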
"""
# If MPI execution is forcefully disabled, return None.
if disable_mpi:
return None
# If MPI is simulated, return the dummy communicator.
if _simulated_mpicomm is not None:
return _simulated_mpicomm
# If we have already initialized MPI, return the cached MPI communicator.
if get_mpicomm._is_initialized:
return get_mpicomm._mpicomm
# Check for environment variables set by mpirun. Variables are from
# http://docs.roguewave.com/threadspotter/2012.1/linux/manual_html/apas03.html
variables = ['PMI_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_MCA_ns_nds_vpid',
'PMI_ID', 'SLURM_PROCID', 'LAMRANK', 'MPI_RANKID',
'MP_CHILD', 'MP_RANK', 'MPIRUN_RANK',
'ALPS_APP_PE', # Cray aprun
]
use_mpi = False
for var in variables:
if var in os.environ:
use_mpi = True
break
# Return None if we are not running on MPI.
if not use_mpi:
logger.debug('Cannot find MPI environment. MPI disabled.')
get_mpicomm._mpicomm = None
get_mpicomm._is_initialized = True
return get_mpicomm._mpicomm
# Initialize MPI
from mpi4py import MPI
mpicomm = MPI.COMM_WORLD
# Override sys.excepthook to abort MPI on exception.
def mpi_excepthook(exc_type, exc_value, exc_traceback):
sys.__excepthook__(exc_type, exc_value, exc_traceback)
node_name = '{}/{}'.format(MPI.COMM_WORLD.rank+1, MPI.COMM_WORLD.size)
# logging.exception() automatically prints sys.exc_info(), but here
# we may want to report the exception traceback of another MPI node,
# so we pass the traceback explicitly.
logger.critical('MPI node {} raised an exception and called Abort()! The '
'exception traceback follows'.format(node_name), exc_info=exc_value)
# Flush everything.
sys.stdout.flush()
sys.stderr.flush()
for logger_handler in logger.handlers:
logger_handler.flush()
# Abort MPI execution.
if MPI.COMM_WORLD.size > 1:
MPI.COMM_WORLD.Abort(1)
# Use our exception handler.
sys.excepthook = mpi_excepthook
# Catch termination signals (SIGINT, SIGTERM, SIGABRT).
def handle_signal(signum, frame):
if mpicomm.size > 1:
mpicomm.Abort(1)
for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]:
signal.signal(sig, handle_signal)
# Cache and return the MPI communicator.
get_mpicomm._is_initialized = True
get_mpicomm._mpicomm = mpicomm
# Report initialization
logger.debug("MPI initialized on node {}/{}".format(mpicomm.rank+1, mpicomm.size))
return mpicomm
get_mpicomm._is_initialized = False # Static variable
def run_single_node(rank, task, *args, **kwargs):
"""Run task on a single node.
If MPI is not activated, this simply runs locally.
Parameters
----------
rank : int
The rank of the MPI communicator that must execute the task.
task : callable
The task to run on node ``rank``.
broadcast_result : bool, optional
If True, the result is broadcasted to all nodes. If False,
only the node executing the task will receive the return
value of the task, and all other nodes will receive None
(default is False).
sync_nodes : bool, optional
If True, the nodes will be synchronized at the end of the
execution (i.e. the task will be blocking) even if the
result is not broadcasted (default is False).
Other Parameters
----------------
args
The ordered arguments to pass to task.
kwargs
The keyword arguments to pass to task.
Returns
-------
result
The return value of the task. This will be None on all nodes
other than ``rank`` unless ``broadcast_result`` is set to True.
Examples
--------
>>> def add(a, b):
...     return a + b
>>> # Run 3+4 on node 0.
>>> run_single_node(0, task=add, a=3, b=4, broadcast_result=True)
7
"""
broadcast_result = kwargs.pop('broadcast_result', False)
sync_nodes = kwargs.pop('sync_nodes', False)
result = None
mpicomm = get_mpicomm()
if mpicomm is not None:
node_name = 'Node {}/{}'.format(mpicomm.rank+1, mpicomm.size)
else:
node_name = 'Single node'
# Execute the task only on the specified node.
if mpicomm is None or mpicomm.rank == rank:
logger.debug('{}: executing {}'.format(node_name, task))
result = task(*args, **kwargs)
# Broadcast the result if required.
if mpicomm is not None:
if broadcast_result is True:
logger.debug('{}: waiting for broadcast of {}'.format(node_name, task))
result = mpicomm.bcast(result, root=rank)
elif sync_nodes is True:
logger.debug('{}: waiting for barrier after {}'.format(node_name, task))
mpicomm.barrier()
# Return result.
return result
def on_single_node(rank, broadcast_result=False, sync_nodes=False):
"""A decorator version of run_single_node.
Decorates a function so that it is always executed with :func:`run_single_node`.
Parameters
----------
rank : int
The rank of the MPI communicator that must execute the task.
broadcast_result : bool, optional
If True the result is broadcasted to all nodes. If False,
only the node executing the function will receive its return
value, and all other nodes will receive None (default is False).
sync_nodes : bool, optional
If True, the nodes will be synchronized at the end of the
execution (i.e. the task will be blocking) even if the
result is not broadcasted (default is False).
See Also
--------
run_single_node
Examples
--------
>>> @on_single_node(rank=0, broadcast_result=True)
... def add(a, b):
...     return a + b
>>> add(3, 4)
7
"""
def _on_single_node(task):
@functools.wraps(task)
def _wrapper(*args, **kwargs):
kwargs['broadcast_result'] = broadcast_result
kwargs['sync_nodes'] = sync_nodes
return run_single_node(rank, task, *args, **kwargs)
return _wrapper
return _on_single_node
class _MpiProcessingUnit(object):
"""Context manager abstracting a single MPI processes and a group of nodes.
Parameters
----------
group_size : None, int or list of int, optional, default is None
If not None, the ``distributed_args`` are distributed among groups of
nodes that are isolated from each other. If an integer, the nodes are
split into equal groups of ``group_size`` nodes. If a list of integers,
the nodes are split in possibly unequal groups.
Attributes
----------
rank : int
Either the rank of the node, or the color of the group.
size : int
Either the size of the mpicomm, or the number of groups.
is_group : bool
True if the parent mpicomm was split into groups of nodes.
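
Examples
--------
Intended usage inside :func:`distribute` (a sketch rather than a runnable
doctest, since it requires a real MPI communicator; ``task`` and
``node_args`` are placeholders)::

    with _MpiProcessingUnit(group_size=2) as unit:
        # Within the context, get_mpicomm() returns the split
        # communicator, so nested distribute() calls are confined
        # to the nodes of this group.
        results = unit.exec_tasks(task, node_args, 'group')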
"""
def __init__(self, group_size):
# Store original mpicomm that we'll have to restore later.
self._parent_mpicomm = get_mpicomm()
# No need to split the comm if group size is None.
if group_size is None:
self._exec_mpicomm = self._parent_mpicomm
self.rank = self._parent_mpicomm.rank
self.size = self._parent_mpicomm.size
else:
# Determine the color that will be assigned to this node.
node_color, n_groups = self._determine_node_color(group_size)
# Split the mpicomm among nodes. Maintain the same node order by using mpicomm.rank as the key.
self._exec_mpicomm = self._parent_mpicomm.Split(color=node_color,
key=self._parent_mpicomm.rank)
self.rank = node_color
self.size = n_groups
@property
def is_group(self):
"""True if this is a group of nodes (i.e. :func:`get_mpicomm` is split)."""
return self._exec_mpicomm != self._parent_mpicomm
def exec_tasks(self, task, distributed_args, propagate_exceptions_to,
*other_args, **kwargs):
"""Run task on the given arguments.
Parameters
----------
task : callable
The task to run on each of the distributed arguments.
distributed_args : iterable
The sequence of arguments assigned to this node.
propagate_exceptions_to : 'all', 'group', or None
When one of the processes raises an exception during the task
execution, this controls which other processes raise it.
Returns
-------
results : list
The list of the return values of the task. One for each argument.
"""
# Determine where to propagate exceptions.
if propagate_exceptions_to == 'all':
exception_mpicomm = self._parent_mpicomm
elif propagate_exceptions_to == 'group':
exception_mpicomm = self._exec_mpicomm
elif propagate_exceptions_to is None:
exception_mpicomm = None
else:
raise ValueError('Unknown value for propagate_exceptions_to: '
'{}'.format(propagate_exceptions_to))
# Determine name for logging.
node_name = 'Node {}/{}'.format(self._exec_mpicomm.rank+1, self._exec_mpicomm.size)
if self.is_group:
node_name = 'Group {}/{} '.format(self.rank+1, self.size) + node_name
# Compute all the results assigned to this node.
results = []
error = None
for distributed_arg in distributed_args:
logger.debug('{}: execute {}({})'.format(node_name, task.__name__, distributed_arg))
try:
results.append(task(distributed_arg, *other_args, **kwargs))
except Exception as e:
# Create an exception with same type and traceback but with node info.
error = type(e)('{}: {}'.format(node_name, str(e)))
error.with_traceback(e.__traceback__)
# When sending the error over the network, the traceback seems to be lost,
# so we create a string version of it, and expose it for others to print.
traceback_str = ''.join(format_exception(type(e), e, e.__traceback__))
error.traceback_str = traceback_str
break
# Propagate any exceptions to the other nodes before raising.
all_errors = []
if exception_mpicomm is not None:
all_errors = exception_mpicomm.allgather(error)
all_errors = [e for e in all_errors if e is not None]
# Each node raises its own exception first and then the others
# (if any). This way the logs will be more informative.
if error is not None:
raise error
elif len(all_errors) > 0:
# Raise the first error received from a different MPI process.
external_error = all_errors[0]
# Include original traceback in the error message (indented 4 spaces).
traceback_str = '\n '.join(external_error.traceback_str.split('\n'))
err_msg = ('{} received an exception from another MPI process. The original'
' stack trace follows:\n{}').format(node_name, traceback_str)
error = type(external_error)(err_msg)
error.with_traceback(external_error.__traceback__)
raise error
return results
def __enter__(self):
# Cache execution mpicomm so that tasks will access the split mpicomm.
get_mpicomm._mpicomm = self._exec_mpicomm
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
# Restore the original mpicomm.
if self.is_group:
self._exec_mpicomm.Free()
get_mpicomm._mpicomm = self._parent_mpicomm
def _determine_node_color(self, group_size):
"""Determine the color of this node."""
try: # Check if this is an integer.
node_color = int(self._parent_mpicomm.rank / group_size)
n_groups = int(np.ceil(self._parent_mpicomm.size / group_size))
except TypeError: # List of integers.
# Check that the requested group division makes sense. The sum
# of all group sizes must be equal to the size of the mpicomm.
cumulative_sum_nodes = np.cumsum(group_size)
if cumulative_sum_nodes[-1] != self._parent_mpicomm.size:
raise ValueError('The group division requested cannot be performed.\n'
'Total number of nodes: {}\n'
'Group nodes: {}'.format(self._parent_mpicomm.size, group_size))
# The first group_size[0] nodes have color 0, the next group_size[1]
# nodes have color 1 etc.
node_color = next(i for i, v in enumerate(cumulative_sum_nodes)
if v > self._parent_mpicomm.rank)
n_groups = len(group_size)
return node_color, n_groups
def distribute(task, distributed_args, *other_args, send_results_to='all',
propagate_exceptions_to='all', sync_nodes=False, group_size=None, **kwargs):
"""Map the task on a sequence of arguments to be executed on different nodes.
If MPI is not activated, this simply runs serially on this node. The
algorithm guarantees that each node will be assigned the same job_id
(i.e. the index of the argument in ``distributed_args``) every time.
Parameters
----------
task : callable
The task to be distributed among nodes. The task will be called as
``task(distributed_args[job_id], *other_args, **kwargs)``, so the parameter
to be distributed must be the first one.
distributed_args : iterable
The sequence of the parameters to distribute among nodes.
send_results_to : int, 'all', or None, optional
If the string 'all', the results will be sent to all nodes. If an
int, the results will be sent only to the node with rank ``send_results_to``.
The return value of distribute depends on the value of this parameter
(default is 'all').
propagate_exceptions_to : 'all', 'group', or None, optional
When one of the processes raises an exception during the task execution,
this controls which other processes raise it (default is 'all'). This
can be 'group' or None only if ``send_results_to`` is None.
sync_nodes : bool, optional
If True, the nodes will be synchronized at the end of the
execution (i.e. the task will be blocking) even if the
result is not shared (default is False).
group_size : None, int or list of int, optional, default is None
If not None, the ``distributed_args`` are distributed among groups of
nodes that are isolated from each other. This is particularly useful
if ``task`` also calls :func:`distribute`, since normally that would result
in unexpected behavior.
If an integer, the nodes are split into equal groups of ``group_size``
nodes. If ``n_nodes % group_size != 0``, the first groups are assigned
more nodes than the last one. If a list of integers, the nodes are
split in possibly unequal groups (see example below).
Other Parameters
----------------
other_args
Other parameters to pass to task beside the assigned distributed
parameters.
kwargs
Keyword arguments to pass to task beside the assigned distributed
parameters.
Returns
-------
all_results : list
All the return values for all the arguments if the results were sent
to this node, or only the return values of the arguments processed by
this node otherwise.
arg_indices : list of int, optional
This is returned as part of a tuple ``(all_results, arg_indices)`` only
if ``send_results_to`` is set to an int or None. In this case ``all_results[i]``
is the return value of ``task(distributed_args[arg_indices[i]])``.
Examples
--------
>>> def square(x):
...     return x**2
>>> distribute(square, [1, 2, 3, 4], send_results_to='all')
[1, 4, 9, 16]
When ``send_results_to`` is not set to 'all', the return value also
includes the indices of the arguments associated with the results.
>>> distribute(square, [1, 2, 3, 4], send_results_to=0)
([1, 4, 9, 16], [0, 1, 2, 3])
Divide the nodes into two groups of 2. The task, in turn, can
distribute another task among the nodes in its own group.
>>> def supertask(list_of_bases):
...     return distribute(square, list_of_bases, send_results_to='all')
>>> list_of_supertask_args = [[1, 2, 3], [4], [5, 6]]
>>> distribute(supertask, distributed_args=list_of_supertask_args,
... send_results_to='all', group_size=2)
[[1, 4, 9], [16], [25, 36]]
"""
n_jobs = len(distributed_args)
# If MPI is not activated, just run serially.
if get_mpicomm() is None:
logger.debug('Running {} serially.'.format(task.__name__))
all_results = [task(job_args, *other_args, **kwargs) for job_args in distributed_args]
if send_results_to == 'all':
return all_results
else:
return all_results, list(range(n_jobs))
# We can't propagate exceptions to a subset of nodes if we need to send all the results.
if send_results_to is not None and propagate_exceptions_to != 'all':
raise ValueError('Cannot propagate exceptions to a subset of nodes '
'when send_results_to is not None')
# Split the default mpicomm into groups if necessary.
with _MpiProcessingUnit(group_size) as processing_unit:
# Determine the jobs that this node has to run.
node_job_ids = range(processing_unit.rank, n_jobs, processing_unit.size)
node_distributed_args = [distributed_args[job_id] for job_id in node_job_ids]
# Run all jobs.
results = processing_unit.exec_tasks(task, node_distributed_args, propagate_exceptions_to,
*other_args, **kwargs)
# If we have split the mpicomm, nodes belonging to the same group
# have duplicate results. We gather the results from only one node per group.
if processing_unit.is_group and get_mpicomm().rank != 0:
results_to_send = []
else:
results_to_send = results
# Share result as requested.
mpicomm = get_mpicomm()
node_name = 'Node {}/{}'.format(mpicomm.rank+1, mpicomm.size)
if send_results_to == 'all':
logger.debug('{}: allgather results of {}'.format(node_name, task.__name__))
all_results = mpicomm.allgather(results_to_send)
elif isinstance(send_results_to, int):
logger.debug('{}: sending results of {} to {}'.format(node_name, task.__name__,
send_results_to))
all_results = mpicomm.gather(results_to_send, root=send_results_to)
# If this is not the receiving node, we can safely return.
if mpicomm.rank != send_results_to:
return results, list(node_job_ids)
else:
assert send_results_to is None # Safety check.
if sync_nodes is True:
logger.debug('{}: waiting for barrier after {}'.format(node_name, task.__name__))
mpicomm.barrier()
return results, list(node_job_ids)
# all_results is a list of list of results. The internal lists of
# results are ordered by rank. We need to reorder the results as a
# flat list of results ordered by job_id.
# job_indices[job_id] is the tuple of indices (rank, i). The result
# of job_id is stored in all_results[rank][i].
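# Example: with 2 nodes and 5 jobs, node 0 runs job_ids [0, 2, 4] and
# node 1 runs [1, 3], so all_results == [[r0, r2, r4], [r1, r3]] and
# job_indices == [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2)], which
# flattens back to [r0, r1, r2, r3, r4], ordered by job_id.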
job_indices = []
max_jobs_per_node = max([len(r) for r in all_results])
for i in range(max_jobs_per_node):
for rank in range(mpicomm.size):
# Not all nodes have executed max_jobs_per_node tasks.
if len(all_results[rank]) > i:
job_indices.append((rank, i))
# Reorder the results.
all_results = [all_results[rank][i] for rank, i in job_indices]
# Return result.
if send_results_to == 'all':
return all_results
else:
return all_results, list(range(n_jobs))
@contextmanager
def delay_termination():
"""Context manager to delay handling of termination signals.
This makes it possible to avoid interrupting tasks, such as writing to
the file system, that could otherwise leave files in a corrupted state.
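
Examples
--------
A sketch of protecting a critical section; ``write_checkpoint`` is a
hypothetical function that must not be interrupted halfway through:

>>> with delay_termination():
...     pass  # e.g. write_checkpoint(); signals are handled after the block.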
"""
signals_to_catch = [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]
old_handlers = {signum: signal.getsignal(signum) for signum in signals_to_catch}
signals_received = {signum: None for signum in signals_to_catch}
def delay_handler(signum, frame):
signals_received[signum] = (signum, frame)
# Set handlers for delay
for signum in signals_to_catch:
signal.signal(signum, delay_handler)
yield # Resume program
# Restore old handlers
for signum, handler in old_handlers.items():
signal.signal(signum, handler)
# Fire delayed signals
for signum, s in signals_received.items():
if s is not None:
old_handlers[signum](*s)
def delayed_termination(func):
"""Decorator that runs the function with :func:`delay_termination`."""
@functools.wraps(func)
def _delayed_termination(*args, **kwargs):
with delay_termination():
return func(*args, **kwargs)
return _delayed_termination
# ==============================================================================
# MPI TEST CLASSES
# ==============================================================================
class _DummyMPIComm:
"""A Dummy MPI Communicator."""
def __init__(self, rank=0, size=4):
self.rank = rank
self.size = size
@contextmanager
def _simulated_mpi_environment(**kwargs):
"""Context manager to temporarily set a simulated MPI environment.
Parameters
----------
**kwargs : dict
The keyword arguments to pass to the ``_DummyMPIComm`` constructor.
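
Examples
--------
A sketch for unit tests: code under the context manager sees a dummy
communicator with the given rank and size, without importing mpi4py.

>>> with _simulated_mpi_environment(rank=1, size=4):
...     get_mpicomm().rank
1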
"""
global _simulated_mpicomm
old_simulated_mpicomm = _simulated_mpicomm
_simulated_mpicomm = _DummyMPIComm(**kwargs)
try:
yield
finally:
_simulated_mpicomm = old_simulated_mpicomm
# ==============================================================================
# MAIN AND TESTS
# ==============================================================================
if __name__ == "__main__":
import doctest
doctest.testmod()