# Source code for yank.mpi

#!/usr/bin/env python

# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================

"""
MPI
===

Utilities to run on MPI.

Provide functions and decorators that simplify running the same code on
multiple nodes. One benefit is that serial and parallel code is exactly
the same.

Global variables
----------------
disable_mpi : bool
    Set this to True to force running serially.

Routines
--------
:func:`get_mpicomm`
    Automatically detect and configure MPI execution and return an
    MPI communicator.
:func:`run_single_node`
    Run a task on a single node.
:func:`on_single_node`
    Decorator version of :func:`run_single_node`.
:func:`distribute`
    Map a task on a sequence of arguments on all the nodes.
:func:`delay_termination`
    A context manager to delay the response to termination signals.
:func:`delayed_termination`
    A decorator version of :func:`delay_termination`.

"""


# ==============================================================================
# GLOBAL IMPORTS
# ==============================================================================

import os
import sys
import signal
import logging
from contextlib import contextmanager

import numpy as np
# TODO drop this when we drop Python 2 support
from openmoltools.utils import wraps_py2

# Module-level logger used by the functions defined in this module.
logger = logging.getLogger(__name__)

# ==============================================================================
# GLOBAL VARIABLES
# ==============================================================================

# When True, get_mpicomm() returns None before any other check, so
# run_single_node() and distribute() take their serial code paths even if
# MPI environment variables are present.
disable_mpi = False


# ==============================================================================
# MAIN FUNCTIONS
# ==============================================================================

def get_mpicomm():
    """Retrieve the MPI communicator for this execution.

    The function automatically detects if the program runs on MPI by checking
    specific environment variables set by various MPI implementations. On
    first execution, it modifies sys.excepthook and registers a handler for
    SIGINT, SIGTERM, SIGABRT to call MPI's ``Abort()`` to correctly terminate
    all processes.

    If only a single MPI process is running there is nothing to abort. In
    that case the signal is forwarded to the handler that was installed
    before, rather than being silently ignored:

    * Python callables are called.
    * ``SIG_DFL`` is restored and the signal is re-sent.
    * ``SIG_IGN`` is respected.

    Returns
    -------
    mpicomm : mpi4py communicator or None
        The communicator for this node, None if the program doesn't run
        with MPI.

    """
    # If MPI execution is forcefully disabled, return None.
    if disable_mpi:
        return None

    # If we have already initialized MPI, return the cached MPI communicator.
    if get_mpicomm._is_initialized:
        return get_mpicomm._mpicomm

    # Check for environment variables set by mpirun. Variables are from
    # http://docs.roguewave.com/threadspotter/2012.1/linux/manual_html/apas03.html
    variables = ['PMI_RANK', 'OMPI_COMM_WORLD_RANK', 'OMPI_MCA_ns_nds_vpid',
                 'PMI_ID', 'SLURM_PROCID', 'LAMRANK', 'MPI_RANKID',
                 'MP_CHILD', 'MP_RANK', 'MPIRUN_RANK']
    use_mpi = any(var in os.environ for var in variables)

    # Return None if we are not running on MPI.
    if not use_mpi:
        logger.debug('Cannot find MPI environment. MPI disabled.')
        get_mpicomm._mpicomm = None
        get_mpicomm._is_initialized = True
        return get_mpicomm._mpicomm

    # Initialize MPI
    from mpi4py import MPI
    mpicomm = MPI.COMM_WORLD

    # Override sys.excepthook to abort MPI on exception.
    def mpi_excepthook(exc_type, exc_value, exc_traceback):
        sys.__excepthook__(exc_type, exc_value, exc_traceback)
        sys.stdout.flush()
        sys.stderr.flush()
        if mpicomm.size > 1:
            mpicomm.Abort(1)
    # Use our exception handler.
    sys.excepthook = mpi_excepthook

    # Catch termination signals. Remember the handlers in place before ours
    # so that a single-process run can still be interrupted.
    termination_signals = [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]
    previous_handlers = {signum: signal.getsignal(signum)
                         for signum in termination_signals}

    def handle_signal(signum, frame):
        if mpicomm.size > 1:
            mpicomm.Abort(1)
        # With a single process there is no one else to abort. Without this
        # fallback the signal would be swallowed (e.g. Ctrl-C ignored).
        previous_handler = previous_handlers[signum]
        if callable(previous_handler):
            previous_handler(signum, frame)
        elif previous_handler != signal.SIG_IGN:
            # SIG_DFL (or a non-Python handler reported as None): restore the
            # default disposition and re-deliver the signal to this process.
            signal.signal(signum, signal.SIG_DFL)
            os.kill(os.getpid(), signum)

    for signum in termination_signals:
        signal.signal(signum, handle_signal)

    # Cache and return the MPI communicator.
    get_mpicomm._is_initialized = True
    get_mpicomm._mpicomm = mpicomm

    # Report initialization
    logger.debug("MPI initialized on node {}/{}".format(mpicomm.rank+1, mpicomm.size))

    return mpicomm


get_mpicomm._is_initialized = False  # Static variable
def run_single_node(rank, task, *args, **kwargs):
    """Run task on a single node.

    If MPI is not activated, this simply runs locally.

    Parameters
    ----------
    rank : int
        The rank of the MPI communicator that must execute the task.
    task : callable
        The task to run on node rank.
    broadcast_result : bool, optional
        If True, the result is broadcasted to all nodes. If False, only the
        node executing the task will receive the return value of the task,
        and all other nodes will receive None (default is False).
    sync_nodes : bool, optional
        If True, the nodes will be synchronized at the end of the execution
        (i.e. the task will be blocking) even if the result is not
        broadcasted (default is False).

    Other Parameters
    ----------------
    args
        The ordered arguments to pass to task.
    kwargs
        The keyword arguments to pass to task.

    Returns
    -------
    result
        The return value of the task. This will be None on all nodes other
        than ``rank`` unless ``broadcast_result`` is set to True.

    Examples
    --------
    >>> def add(a, b):
    ...     return a + b
    >>> # Run 3+4 on node 0.
    >>> run_single_node(0, task=add, a=3, b=4, broadcast_result=True)
    7

    """
    broadcast_result = kwargs.pop('broadcast_result', False)
    sync_nodes = kwargs.pop('sync_nodes', False)
    result = None
    mpicomm = get_mpicomm()
    if mpicomm is not None:
        node_name = 'Node {}/{}'.format(mpicomm.rank+1, mpicomm.size)
    else:
        node_name = 'Single node'

    # Execute the task only on the specified node.
    if mpicomm is None or mpicomm.rank == rank:
        logger.debug('{}: executing {}'.format(node_name, task))
        result = task(*args, **kwargs)

    # Broadcast the result if required. The flags are tested for truthiness
    # rather than identity with True, so numpy booleans and 1/0 also work.
    if mpicomm is not None:
        if broadcast_result:
            logger.debug('{}: waiting for broadcast of {}'.format(node_name, task))
            result = mpicomm.bcast(result, root=rank)
        elif sync_nodes:
            logger.debug('{}: waiting for barrier after {}'.format(node_name, task))
            mpicomm.barrier()

    # Return result.
    return result
def on_single_node(rank, broadcast_result=False, sync_nodes=False):
    """Decorator that routes every call of the decorated function through
    :func:`run_single_node`.

    Parameters
    ----------
    rank : int
        The rank of the MPI communicator that must execute the task.
    broadcast_result : bool, optional
        If True the result is broadcasted to all nodes. If False, only the
        node executing the function will receive its return value, and all
        other nodes will receive None (default is False).
    sync_nodes : bool, optional
        If True, the nodes will be synchronized at the end of the execution
        (i.e. the task will be blocking) even if the result is not
        broadcasted (default is False).

    See Also
    --------
    run_single_node

    Examples
    --------
    >>> @on_single_node(rank=0, broadcast_result=True)
    ... def add(a, b):
    ...     return a + b
    >>> add(3, 4)
    7

    """
    def decorator(task):
        @wraps_py2(task)
        def wrapped(*args, **kwargs):
            # Inject the decorator's options; run_single_node pops them
            # before forwarding the remaining kwargs to the task.
            kwargs.update(broadcast_result=broadcast_result, sync_nodes=sync_nodes)
            return run_single_node(rank, task, *args, **kwargs)
        return wrapped
    return decorator
def _compute_group_color(mpicomm, group_nodes):
    """Assign this node to a group of nodes for :func:`distribute`.

    Parameters
    ----------
    mpicomm : mpi4py communicator
        The communicator whose nodes are split into groups.
    group_nodes : int or list of int
        If an int, the nodes are split into groups of ``group_nodes``
        consecutive ranks (the last group may be smaller). If a list, the
        i-th group contains ``group_nodes[i]`` consecutive ranks.

    Returns
    -------
    color : int
        The index of the group this node belongs to.
    n_groups : int
        The total number of groups.

    Raises
    ------
    ValueError
        If ``group_nodes`` is a list whose sum differs from ``mpicomm.size``.

    """
    try:
        # Integer group size. Integer ceiling division gives the right number
        # of groups also under Python 2, where ``/`` between ints floors.
        color = int(mpicomm.rank // group_nodes)
        n_groups = int(-(-mpicomm.size // group_nodes))
    except TypeError:
        # List of integers.
        # Check that the group division requested makes sense.
        cumulative_sum_nodes = np.cumsum(group_nodes)
        if cumulative_sum_nodes[-1] != mpicomm.size:
            raise ValueError('The group division requested cannot be performed.\n'
                             'Total number of nodes: {}\n'
                             'Group nodes: {}'.format(mpicomm.size, group_nodes))
        # The first group_nodes[0] nodes have color 0, the next group_nodes[1]
        # nodes have color 1 etc.
        color = next(i for i, v in enumerate(cumulative_sum_nodes) if v > mpicomm.rank)
        n_groups = len(group_nodes)
    return color, n_groups


def _reorder_gathered_results(gathered_results):
    """Flatten the per-rank result lists gathered by :func:`distribute` in job order.

    Jobs are dealt round-robin, so ``gathered_results[rank][i]`` is the
    result of the i-th job executed by ``rank``, and all the i-th jobs come
    before any (i+1)-th job. Ranks that sent an empty list (e.g. non-leader
    nodes of a group) contribute nothing.
    """
    max_jobs_per_node = max(len(rank_results) for rank_results in gathered_results)
    return [rank_results[i]
            for i in range(max_jobs_per_node)
            for rank_results in gathered_results
            if i < len(rank_results)]


def distribute(task, distributed_args, *other_args, **kwargs):
    """Map the task on a sequence of arguments to be executed on different nodes.

    If MPI is not activated, this simply runs serially on this node. The
    algorithm guarantees that each node will be assigned to the same job_id
    (i.e. the index of the argument in ``distributed_args``) every time.

    Parameters
    ----------
    task : callable
        The task to be distributed among nodes. The task will be called as
        ``task(distributed_args[job_id], *other_args, **kwargs)``, so the
        parameter to be distributed must be the first one.
    distributed_args : iterable
        The sequence of the parameters to distribute among nodes.
    send_results_to : int or 'all', optional
        If the string 'all', the result will be sent to all nodes. If an
        int, the result will be sent only to the node with rank
        ``send_results_to``. The return value of distribute depends on the
        value of this parameter (default is None).
    sync_nodes : bool, optional
        If True, the nodes will be synchronized at the end of the execution
        (i.e. the task will be blocking) even if the result is not shared
        (default is False).
    group_nodes : None, int or list of int, optional, default is None
        If not None, the ``distributed_args`` are distributed among groups of
        nodes that are isolated from each other. This is particularly useful
        if ``task`` also calls :func:`distribute`, since normally that would
        result in unexpected behavior.

        If an integer, the nodes are split into equal groups of
        ``group_nodes`` nodes. If a list of integers, the nodes are split in
        possibly unequal groups (see example below).

    Other Parameters
    ----------------
    other_args
        Other parameters to pass to task beside the assigned distributed
        parameters.
    kwargs
        Keyword arguments to pass to task beside the assigned distributed
        parameters.

    Returns
    -------
    all_results : list
        All the return values for all the arguments if the results were sent
        to the node, or only the return values of the arguments processed by
        this node otherwise.
    arg_indices : list of int, optional
        This is returned as part of a tuple ``(all_results, arg_indices)``
        only if ``send_results_to`` is set to an int or None. In this case
        ``all_results[i]`` is the return value of
        ``task(distributed_args[arg_indices[i]])``.

    Raises
    ------
    ValueError
        When running on MPI and ``send_results_to`` is not None, 'all' or an
        integer rank, or when the ``group_nodes`` division cannot be
        performed.

    Examples
    --------
    >>> def square(x):
    ...     return x**2
    >>> distribute(square, [1, 2, 3, 4], send_results_to='all')
    [1, 4, 9, 16]

    When send_results_to is not set to `all`, the return value also includes
    the indices of the arguments associated to the result.

    >>> distribute(square, [1, 2, 3, 4], send_results_to=0)
    ([1, 4, 9, 16], [0, 1, 2, 3])

    Divide the nodes in two groups of 2. The task, in turn, can distribute
    another task among the nodes in its own group.

    >>> def supertask(list_of_bases):
    ...     return distribute(square, list_of_bases, send_results_to='all')
    >>> list_of_supertask_args = [[1, 2, 3], [4], [5, 6]]
    >>> distribute(supertask, distributed_args=list_of_supertask_args,
    ...            send_results_to='all', group_nodes=2)
    [[1, 4, 9], [16], [25, 36]]

    """
    send_results_to = kwargs.pop('send_results_to', None)
    sync_nodes = kwargs.pop('sync_nodes', False)
    group_nodes = kwargs.pop('group_nodes', None)
    mpicomm = get_mpicomm()
    n_jobs = len(distributed_args)
    # Callables such as functools.partial have no __name__.
    task_name = getattr(task, '__name__', repr(task))

    # If MPI is not activated, just run serially.
    if mpicomm is None:
        logger.debug('Running {} serially.'.format(task_name))
        all_results = [task(job_args, *other_args, **kwargs) for job_args in distributed_args]
        if send_results_to == 'all':
            return all_results
        else:
            return all_results, list(range(n_jobs))

    # Validate send_results_to before any work is done on the nodes.
    sends_to_single_node = isinstance(send_results_to, (int, np.integer))
    if not (send_results_to is None or send_results_to == 'all' or sends_to_single_node):
        raise ValueError("send_results_to must be None, 'all' or an integer rank, "
                         "got {!r}".format(send_results_to))

    # Determine the jobs that this node has to run.
    # If we need to group nodes, split the default mpicomm.
    if group_nodes is not None:
        # Store original mpicomm that we'll have to restore later.
        original_mpicomm = mpicomm
        color, n_groups = _compute_group_color(mpicomm, group_nodes)

        # Split the mpicomm among nodes. Maintain same order using mpicomm.rank as rank.
        mpicomm = original_mpicomm.Split(color=color, key=mpicomm.rank)

        # Distribute distributed_args by color.
        node_job_ids = range(color, n_jobs, n_groups)
        node_name = 'Group {}/{}, Node {}/{}'.format(color+1, n_groups,
                                                     mpicomm.rank+1, mpicomm.size)

        # Cache new mpicomm so that task() will access the split mpicomm.
        get_mpicomm._mpicomm = mpicomm
    else:
        # Distribute distributed_args by mpicomm.rank.
        node_job_ids = range(mpicomm.rank, n_jobs, mpicomm.size)
        node_name = 'Node {}/{}'.format(mpicomm.rank+1, mpicomm.size)

    try:
        # Compute all the results assigned to this node.
        results = []
        for job_id in node_job_ids:
            distributed_arg = distributed_args[job_id]
            logger.debug('{}: execute {}({})'.format(node_name, task_name, distributed_arg))
            results.append(task(distributed_arg, *other_args, **kwargs))

        # If we have split the mpicomm, nodes belonging to the same group
        # have duplicate results. We gather only results from one node.
        if group_nodes is None or mpicomm.rank == 0:
            results_to_send = results
        else:
            results_to_send = []
    finally:
        # Restore the original mpicomm, also when task() raises, so that the
        # split communicator is freed and get_mpicomm() stops returning it.
        if group_nodes is not None:
            mpicomm.Free()
            mpicomm = original_mpicomm
            get_mpicomm._mpicomm = original_mpicomm

    # Share result as specified.
    if send_results_to == 'all':
        logger.debug('{}: allgather results of {}'.format(node_name, task_name))
        all_results = mpicomm.allgather(results_to_send)
    elif sends_to_single_node:
        logger.debug('{}: gather results of {}'.format(node_name, task_name))
        all_results = mpicomm.gather(results_to_send, root=send_results_to)

        # If this is not the receiving node, we can safely return.
        if mpicomm.rank != send_results_to:
            return results, list(node_job_ids)
    else:
        # send_results_to is None (validated above).
        if sync_nodes:
            logger.debug('{}: waiting for barrier after {}'.format(node_name, task_name))
            mpicomm.barrier()
        return results, list(node_job_ids)

    # all_results is a list of lists of results ordered by rank. Reorder
    # them as a flat list of results ordered by job_id.
    all_results = _reorder_gathered_results(all_results)

    # Return result.
    if send_results_to == 'all':
        return all_results
    else:
        return all_results, list(range(n_jobs))
@contextmanager
def delay_termination():
    """Context manager to delay handling of termination signals.

    This allows to avoid interrupting tasks such as writing to the file
    system, which could result in the corruption of the file.

    Inside the ``with`` block, SIGINT, SIGTERM and SIGABRT are recorded
    instead of handled. On exit, the previous handlers are restored, even if
    the block raised. Each recorded signal is then delivered to its previous
    handler:

    * Python callables are called with the recorded ``(signum, frame)``.
    * ``SIG_DFL`` handlers get the signal re-sent to this process, so the
      default action takes place.
    * ``SIG_IGN`` handlers drop it.

    """
    signals_to_catch = [signal.SIGINT, signal.SIGTERM, signal.SIGABRT]
    old_handlers = {signum: signal.getsignal(signum) for signum in signals_to_catch}
    signals_received = {signum: None for signum in signals_to_catch}

    def delay_handler(signum, frame):
        signals_received[signum] = (signum, frame)

    # Set handlers for delay
    for signum in signals_to_catch:
        signal.signal(signum, delay_handler)

    try:
        yield  # Resume program
    finally:
        # Restore old handlers. Doing it in ``finally`` guarantees that an
        # exception in the body doesn't leave the delaying handlers installed.
        for signum, handler in old_handlers.items():
            signal.signal(signum, handler)

        # Fire delayed signals
        for signum, received in signals_received.items():
            if received is None:
                continue
            handler = old_handlers[signum]
            if callable(handler):
                handler(*received)
            elif handler != signal.SIG_IGN:
                # SIG_DFL is not callable: re-send the signal now that the
                # default disposition has been restored.
                os.kill(os.getpid(), signum)
def delayed_termination(func):
    """Wrap ``func`` so that each call runs inside :func:`delay_termination`."""
    @wraps_py2(func)
    def wrapper(*args, **kwargs):
        with delay_termination():
            outcome = func(*args, **kwargs)
        return outcome
    return wrapper
# ==============================================================================
# MAIN AND TESTS
# ==============================================================================

if __name__ == "__main__":
    # Run the examples embedded in this module's docstrings.
    from doctest import testmod
    testmod()