Source code for yank.yank

#!/usr/bin/env python

# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================

"""
Yank
====

Interface for automated free energy calculations.

"""

# ==============================================================================
# GLOBAL IMPORTS
# ==============================================================================

import abc
import copy
import time
import logging
import functools
import importlib
import collections

import mdtraj
import pandas
import numpy as np
import openmmtools as mmtools
from simtk import unit, openmm

from . import utils, pipeline, repex, mpi
from .restraints import RestraintState, RestraintParameterError, V0

from typing import Union, Tuple, List, Set

# Module-level logger, named after this module through ``__name__``.
logger = logging.getLogger(__name__)


# ==============================================================================
# TOPOGRAPHY
# ==============================================================================

class Topography(object):
    """A class mapping and labelling the different components of a system.

    The object holds the topology of a system and offers convenience functions
    to identify its various parts such as solvent, receptor, ions and ligand
    atoms.

    A molecule should be labeled as a ligand, only if there is also a receptor.
    If there is only a single molecule its atom indices can be obtained from
    solute_atoms instead. In ligand-receptor system, solute_atoms provides the
    atom indices for both molecules.

    Parameters
    ----------
    topology : mdtraj.Topology or simtk.openmm.app.Topology
        The topology object specifying the system.
    ligand_atoms : iterable of int or str, optional
        The atom indices of the ligand. A string is interpreted as an mdtraj
        DSL specification of the ligand atoms.
    solvent_atoms : iterable of int or str, optional
        The atom indices of the solvent. A string is interpreted as an mdtraj
        DSL specification of the solvent atoms. If 'auto', a list of common
        solvent residue names will be used to automatically detect solvent
        atoms (default is 'auto').

    Attributes
    ----------
    ligand_atoms
    receptor_atoms
    solute_atoms
    solvent_atoms
    ions_atoms

    """

    # Regions that are always available through get_region() and compound
    # region selections. Each name must resolve to an attribute of this class
    # ('ion_atoms' is an alias property of 'ions_atoms').
    _BUILT_IN_REGIONS = ('ligand_atoms', 'receptor_atoms', 'solute_atoms', 'solvent_atoms', 'ion_atoms')
    # Words used as logical operators in compound region selections.
    _PROTECTED_REGION_NAMES = ('and', 'or')

    def __init__(self, topology, ligand_atoms=None, solvent_atoms='auto'):
        # Determine if we need to convert the topology to mdtraj.
        if isinstance(topology, mdtraj.Topology):
            self._topology = topology
        else:
            self._topology = mdtraj.Topology.from_openmm(topology)

        # Initialize regions, this has to come before solvent/ligand atoms to
        # ensure select() can look up regions while they are being assigned.
        self._regions = {}

        # Handle default ligand atoms.
        if ligand_atoms is None:
            ligand_atoms = []

        # Once ligand and solvent atoms are defined, every other region is implied.
        # The solvent goes first: the ligand setter checks the receptor atoms,
        # which are derived from the solvent.
        self.solvent_atoms = solvent_atoms
        self.ligand_atoms = ligand_atoms

    @property
    def topology(self):
        """mdtraj.Topology: A copy of the topology (read-only)."""
        return copy.deepcopy(self._topology)

    @property
    def ligand_atoms(self):
        """The atom indices of the ligand as list

        This can be empty if this :class:`Topography` doesn't represent a receptor-ligand
        system. Use solute_atoms to obtain the atom indices of the molecule if
        this is the case.

        If assigned to a string, it will be interpreted as an mdtraj DSL specification
        of the atom indices.

        """
        return self._ligand_atoms

    @ligand_atoms.setter
    def ligand_atoms(self, value):
        self._ligand_atoms = self.select(value)

        # Safety check: with a ligand there should always be a receptor.
        if len(self._ligand_atoms) > 0 and len(self.receptor_atoms) == 0:
            raise ValueError('Specified ligand but cannot find '
                             'receptor atoms. Ligand: {}'.format(value))

    @property
    def receptor_atoms(self):
        """The atom indices of the receptor as list (read-only).

        This can be empty if this Topography doesn't represent a receptor-ligand
        system. Use solute_atoms to obtain the atom indices of the molecule if
        this is the case.

        """
        # If there's no ligand, there's no receptor.
        if len(self._ligand_atoms) == 0:
            return []

        # Create a set for fast searching.
        ligand_atomset = frozenset(self._ligand_atoms)
        # Receptor atoms are all solute atoms that are not ligand.
        return [i for i in self.solute_atoms if i not in ligand_atomset]

    @property
    def solute_atoms(self):
        """The atom indices of the non-solvent molecule(s) (read-only).

        Practically, this are all the indices of the atoms that are not considered
        solvent. In a receptor-ligand system, this includes the atom indices of
        both the receptor and the ligand.

        """
        # Create a set for fast searching.
        solvent_atomset = frozenset(self._solvent_atoms)
        # The solute is everything that is not solvent.
        return [i for i in range(self._topology.n_atoms) if i not in solvent_atomset]

    @property
    def solvent_atoms(self):
        """The atom indices of the solvent molecules.

        This includes eventual ions. If assigned to a string, it will be
        interpreted as an mdtraj DSL specification of the atom indices. If
        assigned to 'auto', a set of solvent auto indices is automatically
        built from common solvent residue names.

        """
        return self._solvent_atoms

    @solvent_atoms.setter
    def solvent_atoms(self, value):
        # If the user doesn't provide a solvent description,
        # we use a default set of resnames in mdtraj.
        # The isinstance guard avoids an elementwise (ambiguous) comparison
        # when value is a NumPy array of indices.
        if isinstance(value, str) and value == 'auto':
            solvent_resnames = mdtraj.core.residue_names._SOLVENT_TYPES
            self._solvent_atoms = [atom.index for atom in self._topology.atoms
                                   if atom.residue.name in solvent_resnames]
        else:
            self._solvent_atoms = self.select(value)

    @property
    def ions_atoms(self):
        """The indices of all ions atoms in the solvent (read-only)."""
        # Ions are all atoms of the solvent whose residue name show a charge.
        ions = []
        for i in self._solvent_atoms:
            residue_name = self._topology.atom(i).residue.name
            if '-' in residue_name or '+' in residue_name:
                ions.append(i)
        return ions

    @property
    def ion_atoms(self):
        """Alias of ``ions_atoms`` (read-only).

        ``'ion_atoms'`` is the name listed among the built-in regions, so it
        must resolve to an attribute for :func:`get_region` and compound
        region selections to work.
        """
        return self.ions_atoms

    def add_region(self, region_name, region_selection, subset=None):
        """
        Add a region to the Topography based on a selection string of atoms.

        The selection accepts multiple formats such as a DSL string, a SMARTS
        selection string, or hard coded atom indices. The selection is converted
        to a list of absolute atom indices.

        Parameters
        ----------
        region_name : str
            Name of the region. This must not already be a region (custom or
            built-in) and must not be one of the reserved logical operators
            ``and``/``or``.
        region_selection : str or list of ints
            Atom selection identifier, either a MDTraj DSL string, a SMARTS string, a compound region selection,
            or a hard-coded list of ints. SMARTS selections are not implemented yet.
        subset : str or list of ints, or None
            Atom selection sub-region to filter the atom selection through. This is a way to define your new region
            as a relative selection to the subset. Follows the same conditions as ``region_selection``.

        Raises
        ------
        KeyError
            If the name is already a region or is a reserved word.
        ValueError
            If the selection cannot be parsed or is ambiguous.

        """
        self._check_existing_regions(region_name)
        self._check_reserved_words(region_name)
        atom_selection = self.select(region_selection, subset=subset)
        self._regions[region_name] = atom_selection

    def remove_region(self, region_name):
        """
        Remove a previously added region from this Topography.

        This only affects regions added through the :func:`add_region` function.
        Does nothing if the region was not previously added.

        Parameters
        ----------
        region_name : str
            Name of the region to remove

        """
        self._regions.pop(region_name, None)

    def get_region(self, region_name) -> List[int]:
        """
        Retrieve the atom indices of the given region.

        This function will also fetch the built-in regions.

        Parameters
        ----------
        region_name : str
            Name of the region to fetch. Can use both custom regions or built in ones such as "ligand_atoms"

        Returns
        -------
        region : list of ints
            Atom integers which comprise the region (a copy, safe to modify)

        Raises
        ------
        KeyError
            If region is not part of the Topography

        """
        if region_name not in self:
            raise KeyError("Cannot find region \"{}\" in this Topography.".format(region_name))
        # Built-ins such as ligand_atoms/solvent_atoms return internal lists,
        # so copy them too; nobody can tweak a region outside of the API.
        if region_name in self._BUILT_IN_REGIONS:
            return copy.copy(getattr(self, region_name))
        return copy.copy(self._regions[region_name])

    def select(self, selection, as_set=False, sort_by='auto', subset=None) -> Union[List[int], Set[int]]:
        """
        Select atoms based on selection format which can be a number of formats:

        * Single integer (a bit redundant)
        * Iterable of integer (also a bit redundant)
        * Complex Region selection
        * MDTraj String
        * SMARTS Selection (not implemented yet)

        Region selections are built from sets and contain no duplicate atom
        numbers; explicitly provided integers are returned as given (including
        duplicates) unless ``as_set`` is True. The ``sort_by`` method controls
        how the output is sorted before being returned, see the details in the
        Parameters block for information about each option.

        The Complex Region string returns the set of atoms derived from the arguments
        using logical operators ``and`` and ``or`` along with grouping through parenthesis.
        For example, assume you have two regions ``regionA = [0,1,2,3]`` and ``regionB = [2,3,4,5]``.
        You can do operations such as the following:

        ``regionA and regionB`` yields ``[2,3]``, which is the intersection of the regions.

        ``regionA or regionB`` yields ``[0,1,2,3,4,5]``, which is the union of the regions.

        More complex statements with more regions will also work, and statements can be grouped with ``()``.

        The ``subset`` keyword filters the ``selection`` relative to this subset selection. If not None,
        the subset is processed first through this same function, then the primary selection is processed
        relative to it. ``subset`` follows the same conditions as ``selection``, but sort order for subset is
        ignored.

        If your ``selection`` would pick atoms that are NOT part of the subset, then those atoms are NOT RETURNED.
        If your ``selection`` is an integer or some sequence of integer, then the indices are relative to the
        ``subset`` (negative or out-of-range indices are dropped).
        Final atom numbers will be absolute to the whole Topography.

        Parameters
        ----------
        selection : str, list of ints, or int
            String defining the selection. One-shot iterators (e.g. generators)
            of ints are accepted too.
        as_set : bool, Default False
            Determines output format. Returns a Set if True, otherwise output is a list.
        sort_by : str or None, Default: 'auto'
            Determine how to sort the output if ``as_set`` is False.

            * 'auto': Let the selection string determine how to sort it out based on its priorities
            * 'index': Atoms are sorted index, smallest to largest
            * 'region_order': Atoms are sorted by which region in the provided ``selection_string`` occurs first.
              So if your expression is ``region1 and region2``, then the output will be atoms which appear in
              ``region1`` first. Parenthesis are **ignored** so the expression ``region1 and (region2 or region3)``
              will prioritize ``region1`` first for sorting.
              This option only works for expressions on Regions, other selections will fall back to ``None``
            * ``None``: Sorting is left up to however the selection string is processed by their respective drivers,
              or by whatever order the set operations returns it. Not guaranteed to be deterministic in all cases.

        subset : None, str, list of ints, or int; Optional; Default: None
            Set of atoms to make a relative selection to. Follows the same rules as ``selection`` for valid inputs.
            If None, ``selection`` is on whole Topography.

        Returns
        -------
        selection : list or set
            Returns the selected atoms as either a list or set based on ``as_set`` keyword. Order of the output
            is determined by the ``sort_by`` keyword. If ``as_set`` is ``True``, the ``sort_by`` option has
            no effect.

        Raises
        ------
        ValueError
            If no selector can parse ``selection``, or if several selectors
            parse it into different atoms.

        """
        # Materialize one-shot iterators: the type check in selector_picker
        # would otherwise exhaust them before the selection is read.
        if hasattr(selection, '__next__'):
            selection = list(selection)

        # Resolve the subset (if any) to absolute atom indices first.
        if subset is not None:
            subset_atoms = self.select(subset, sort_by='index', as_set=False, subset=None)
            subset_atomset = frozenset(subset_atoms)
        else:
            subset_atoms = None
            subset_atomset = None

        def atom_mapping(atom):
            """Map an index relative to the subset to an absolute index, or None if outside of it."""
            if subset_atoms is None:
                return atom
            # Negative indices would silently wrap around the subset list.
            if atom < 0:
                return None
            try:
                return subset_atoms[atom]
            except IndexError:
                return None

        def in_selection_scope(atom):
            """Check whether an absolute atom index belongs to the selectable atoms."""
            if subset_atomset is None:
                return 0 <= atom < self._topology.n_atoms
            return atom in subset_atomset

        # Helper functions for handling the sorting, atoms are in absolute terms at this point.
        def sort_output_index(sortable):
            return sorted(sortable)

        def sort_output_region_order(sortable):
            # Only valid when selection is a region expression string. Strip
            # parentheses so grouped region names such as "(regionB" are found.
            tokens = selection.replace('(', ' ').replace(')', ' ').split()
            region_order = [token for token in tokens if token in self]
            # Because only "and" and "or" are allowed, every selected atom
            # belongs to at least one of these regions.
            remaining = set(sortable)
            final_output = []
            for region_name in region_order:
                for atom_number in self.get_region(region_name):
                    if atom_number in remaining:
                        final_output.append(atom_number)
                        remaining.discard(atom_number)
            return final_output

        def sort_output_none(sortable):
            return list(sortable)

        sortable_dispatch = {'index': sort_output_index,
                             'region_order': sort_output_region_order,
                             None: sort_output_none}

        class Selector(abc.ABC):
            """Base class of the selection strategies, used only through classmethods."""
            # Sort options this selector supports; the first entry is the
            # fallback used when the requested sort_by is not supported.
            SORT_PRIORITY = (None,)

            @classmethod
            @abc.abstractmethod
            def select(cls, selection_input) -> Union[Tuple[List[int], None], Tuple[None, Exception]]:
                """Return ``(output, None)`` on success or ``(None, error)`` on failure."""

            @classmethod
            def sort_selection(cls, selected_atoms) -> Union[List[int], Set[int]]:
                if as_set:
                    return set(selected_atoms)
                if sort_by in cls.SORT_PRIORITY:
                    return sortable_dispatch[sort_by](selected_atoms)
                return sortable_dispatch[cls.SORT_PRIORITY[0]](selected_atoms)

        class SelectRegion(Selector):
            SORT_PRIORITY = ('index', 'region_order', None)

            @classmethod
            def select(cls, region_string):
                # Any failure to evaluate the string as a region expression means
                # it is not one; keep the error for the aggregated message so that
                # other selectors (e.g. DSL) get a chance. The exact exceptions of
                # the expression evaluator are not visible here, so catch the usual ones.
                try:
                    region_atoms = list(self._get_region_set(region_string))
                except (SyntaxError, ValueError, TypeError, KeyError) as e:
                    return None, e
                # Regions hold absolute indices: only drop atoms outside the subset.
                return [atom for atom in region_atoms if in_selection_scope(atom)], None

        class SelectDsl(Selector):
            SORT_PRIORITY = ('index', None)

            @classmethod
            def select(cls, dsl_string):
                # Build the subset topology only when a DSL query actually needs it.
                if subset_atoms is None:
                    dsl_topology = self._topology
                else:
                    dsl_topology = self._topology.subset(subset_atoms)
                try:
                    relative_atoms = dsl_topology.select(dsl_string).tolist()
                except ValueError as e:
                    return None, e
                mapped_atoms = (atom_mapping(atom) for atom in relative_atoms)
                return [atom for atom in mapped_atoms if atom is not None], None

        class SelectSmarts(Selector):
            @classmethod
            def select(cls, smarts_string):
                # Skeletal structure, not used yet.
                return None, NotImplementedError("This method has not been implemented yet")

        class SelectIterable(Selector):
            # Manually specified integers keep the user's order: only the
            # default (None) sorting is offered.
            @classmethod
            def select(cls, iterable):
                try:
                    atoms = iterable.tolist()
                except AttributeError:
                    try:
                        atoms = list(iterable)
                    except Exception as e:
                        return None, e
                except Exception as e:
                    return None, e
                mapped_atoms = (atom_mapping(atom) for atom in atoms)
                return [atom for atom in mapped_atoms if atom is not None], None

        class SelectInt(SelectIterable):
            # A single index is processed as a one-element iterable.
            @classmethod
            def select(cls, integer):
                return SelectIterable.select([integer])

        # Dispatcher to parse the selection type and return the valid selection classes.
        @functools.singledispatch
        def selector_picker(selection_input) -> Tuple[Tuple[Selector, ...], str]:
            if not all(isinstance(i, (np.integer, int)) for i in selection_input):
                raise ValueError("Selection {} is not iterable of ints or any other readable type such as string!"
                                 "Unable to parse!".format(selection))
            return (SelectIterable,), 'iterable'

        @selector_picker.register(np.integer)
        @selector_picker.register(int)
        def int_selector(_) -> Tuple[Tuple[Selector, ...], str]:
            return (SelectInt,), 'integer'

        @selector_picker.register(str)
        def string_selection(_) -> Tuple[Tuple[Selector, ...], str]:
            return (SelectRegion, SelectDsl, SelectSmarts), "string"

        registered_selectors, region_selector_types = selector_picker(selection)

        selector_outputs = []
        selector_errors = []
        selector_names = []  # For handling error messages.
        for selector in registered_selectors:
            selector_names.append(selector.__name__)
            selection_output, errors = selector.select(selection)
            selector_outputs.append(selection_output)
            selector_errors.append(errors)

        # Indices of the selectors that produced an output.
        valid_selectors = [index for index, output in enumerate(selector_outputs) if output is not None]

        if len(valid_selectors) == 0:
            base_error_string = "The selection {} could not be parsed by any " \
                                "selector in the {} class!".format(selection, region_selector_types)
            base_error_string += ("\nThe following errors were thrown by the selectors which may help you determine "
                                  "why the selection was not parsed:")
            for selector_name, selector_error in zip(selector_names, selector_errors):
                base_error_string += "\n {}: {}".format(selector_name, selector_error)
            raise ValueError(base_error_string)

        base_index = valid_selectors[0]
        # Compare as sets: different selectors may return the same atoms in a different order.
        base_selection = set(selector_outputs[base_index])
        for index in valid_selectors[1:]:
            if set(selector_outputs[index]) != base_selection:
                raise ValueError("The selection {} was ambiguous as the following selectors returned valid "
                                 "outputs, but they were different! Consider refining your selection string or "
                                 "changing region names to not align with other selection string: \n"
                                 " {}".format(selection, [selector_names[i] for i in valid_selectors]))

        # All valid selectors agree, so draw the output from the first one.
        return registered_selectors[base_index].sort_selection(selector_outputs[base_index])

    # -------------------------------------------------------------------------
    # Serialization
    # -------------------------------------------------------------------------

    def __getstate__(self):
        # We serialize the MDTraj Topology through a pandas dataframe because
        # it doesn't implement __getstate__ and __setstate__ methods that
        # guarantee future-compatibility. This serialization protocol will be
        # compatible at least until the Topology API is broken.
        atoms, bonds = self._topology.to_dataframe()
        serialized_topology = {'atoms': atoms.to_json(orient='records'),
                               'bonds': bonds.tolist()}
        return dict(topology=serialized_topology,
                    ligand_atoms=self._ligand_atoms,
                    solvent_atoms=self._solvent_atoms,
                    regions=self._regions)

    def __setstate__(self, serialization):
        import io
        topology_dict = serialization['topology']
        # Wrap the JSON text in a buffer: passing literal JSON strings to
        # read_json is deprecated in recent pandas versions.
        atoms = pandas.read_json(io.StringIO(topology_dict['atoms']), orient='records')
        bonds = np.array(topology_dict['bonds'])
        self._topology = mdtraj.Topology.from_dataframe(atoms, bonds)
        self._ligand_atoms = serialization['ligand_atoms']
        self._solvent_atoms = serialization['solvent_atoms']
        # Tolerate serializations without custom regions (presumably written
        # by older versions).
        self._regions = serialization.get('regions', {})

    # -------------------------------------------------------------------------
    # Internal-usage
    # -------------------------------------------------------------------------

    def _check_existing_regions(self, region_string):
        """Make sure regions don't overlap"""
        if region_string in self:
            raise KeyError("{} is already part of this Topology! "
                           "Cannot overwrite built-in regions!".format(region_string))

    def _check_reserved_words(self, region_string):
        """Make sure region is NOT a protected name"""
        if region_string in self._PROTECTED_REGION_NAMES:
            raise KeyError("{} is a protected keyword for logical operations and "
                           "cannot be used as a region name".format(region_string))

    def __contains__(self, item):
        """Check the in operator to see if region is in this class"""
        return item in self._regions or item in self._BUILT_IN_REGIONS

    def _get_region_set(self, region_set_string):
        """
        Get a new region as a logical combination of several region sets.

        See docs in :func:`select` for details about the logic in Complex

        Parameters
        ----------
        region_set_string : str
            Region combination string using region names, logical operators, and parenthesis grouping.

        Returns
        -------
        combined_region : set
            Set of combined regions
        """
        # Combine regions, start with keys only since we are converting to a set
        combined_region_keys = tuple(self._regions.keys()) + self._BUILT_IN_REGIONS
        # Cast regions to set, but only if they are in the region_set_string variables
        variables = {key: set(self.get_region(key)) for key in combined_region_keys if key in region_set_string}
        parsed_output = mmtools.utils.math_eval(region_set_string, variables=variables)
        return parsed_output
# ==============================================================================
# Classes that define a single thermodynamic leg (phase) of the calculation
# ==============================================================================
class IMultiStateSampler(mmtools.utils.SubhookedABCMeta):
    """A sampler for multiple thermodynamic states.

    This interface documents the properties and methods that need to be
    exposed by the sampler object to be compatible with the class
    :class:`AlchemicalPhase`.

    Attributes
    ----------
    number_of_iterations
    iteration
    metadata
    sampler_states

    """

    # NOTE(review): unlike the other properties this one is not declared
    # abstract; kept as-is to preserve which classes satisfy the interface.
    @property
    def number_of_iterations(self):
        """int: the total number of iterations to run."""
        pass

    # ``property`` stacked on ``abc.abstractmethod`` replaces the deprecated
    # ``abc.abstractproperty`` with identical semantics.
    @property
    @abc.abstractmethod
    def iteration(self):
        """int: the current iteration."""
        pass

    @property
    @abc.abstractmethod
    def metadata(self):
        """dict: a copy of the metadata dictionary passed on creation."""
        pass

    @property
    @abc.abstractmethod
    def sampler_states(self):
        """list of SamplerState: the sampler states at the current iteration."""
        pass

    @abc.abstractmethod
    def create(self, thermodynamic_state, sampler_states, storage,
               unsampled_thermodynamic_states, metadata):
        """Create new simulation and initialize the storage.

        Parameters
        ----------
        thermodynamic_state : list of openmmtools.states.ThermodynamicState
            The thermodynamic states for the simulation.
        sampler_states : openmmtools.states.SamplerState or list
            One or more sets of initial sampler states. If a list of SamplerStates,
            they will be assigned to thermodynamic states in a round-robin fashion.
        storage : str or Reporter
            If str: The path to the storage file. Reads defaults from the :class:`yank.repex.Reporter` class
            If :class:`yank.repex.Reporter`: Reads the reporter settings for files and options
            In the future this will be able to take a Storage class as well.
        unsampled_thermodynamic_states : list of openmmtools.states.ThermodynamicState
            These are ThermodynamicStates that are not propagated, but their
            reduced potential is computed at each iteration for each replica.
            These energies can be used as data for reweighting schemes.
        metadata : dict
            Simulation metadata to be stored in the file.

        """
        pass

    @abc.abstractmethod
    def minimize(self, tolerance, max_iterations):
        """Minimize all states.

        Parameters
        ----------
        tolerance : simtk.unit.Quantity
            Minimization tolerance (units of energy/mole/length, default is
            ``1.0 * unit.kilojoules_per_mole / unit.nanometers``).
        max_iterations : int
            Maximum number of iterations for minimization. If 0, minimization
            continues until converged.

        """
        pass

    @abc.abstractmethod
    def equilibrate(self, n_iterations, mcmc_moves=None):
        """Equilibrate all states.

        Parameters
        ----------
        n_iterations : int
            Number of equilibration iterations.
        mcmc_moves : MCMCMove or list of MCMCMove, optional
            Optionally, the MCMCMoves to use for equilibration can be
            different from the ones used in production (default is None).

        """
        pass

    @abc.abstractmethod
    def run(self, n_iterations=None):
        """Run the simulation.

        This runs at most :attr:`number_of_iterations` iterations. Use :func:`extend`
        to pass the limit.

        Parameters
        ----------
        n_iterations : int, optional
            If specified, only at most the specified number of iterations
            will be run (default is None).

        """
        pass

    @abc.abstractmethod
    def extend(self, n_iterations):
        """Extend the simulation by the given number of iterations.

        Contrarily to :func:`run`, this will extend the number of iterations past
        :attr:`number_of_iterations` if requested.

        Parameters
        ----------
        n_iterations : int
            The number of iterations to run.

        """
        pass
class AlchemicalPhase(object):
    """A single thermodynamic leg (phase) of an alchemical free energy calculation.

    This class wraps around a general MultiStateSampler and handles the creation
    of an alchemical free energy calculation.

    Parameters
    ----------
    sampler : MultiStateSampler
        The sampler instance implementing the :class:`IMultiStateSampler` interface.

    Attributes
    ----------
    iteration
    number_of_iterations
    is_completed

    """

    def __init__(self, sampler):
        self._sampler = sampler

    @staticmethod
    def from_storage(storage):
        """Static constructor from an existing storage file.

        Parameters
        ----------
        storage : str or Reporter
            If str: The path to the primary storage file. Default checkpointing
            options are used in this case.
            If :class:`yank.repex.Reporter`: loads from the reporter class,
            including checkpointing information.
            In the future this will be able to take a Storage class as well.

        Returns
        -------
        alchemical_phase : AlchemicalPhase
            A new instance of :class:`AlchemicalPhase` in the same state of the
            last stored iteration.

        Raises
        ------
        RuntimeError
            If the storage file does not exist.

        """
        # Accept either a path to the storage file or an already-built Reporter.
        if isinstance(storage, str):
            reporter = repex.Reporter(storage)
        else:
            reporter = storage

        # Check if netcdf file exists.
        if not reporter.storage_exists():
            reporter.close()
            raise RuntimeError('Storage file at {} does not exists; cannot resume.'.format(reporter.filepath))

        # TODO: this should skip the Reporter and use the Storage to read storage.metadata.
        # Open Reporter for reading and read metadata. Close it even if the
        # metadata cannot be read so the file handle is not leaked.
        reporter.open(mode='r')
        try:
            metadata = reporter.read_dict('metadata')
        finally:
            reporter.close()

        # Retrieve the sampler class from the fully-qualified name that
        # create() stores under 'sampler_full_name'.
        sampler_full_name = metadata['sampler_full_name']
        module_name, cls_name = sampler_full_name.rsplit('.', 1)
        module = importlib.import_module(module_name)
        cls = getattr(module, cls_name)

        # Resume sampler and return new AlchemicalPhase.
        sampler = cls.from_storage(reporter)
        return AlchemicalPhase(sampler)

    @property
    def iteration(self):
        """int: the current iteration (read-only)."""
        return self._sampler.iteration

    @property
    def number_of_iterations(self):
        """int: the total number of iterations to run."""
        return self._sampler.number_of_iterations

    @number_of_iterations.setter
    def number_of_iterations(self, value):
        self._sampler.number_of_iterations = value

    @property
    def is_completed(self):
        """bool: whether the sampler has completed.

        If the sampler exposes its own ``is_completed`` attribute, that value
        is returned. Otherwise the phase is considered complete once the current
        iteration has reached ``number_of_iterations``.
        """
        try:
            return self._sampler.is_completed
        except AttributeError:
            return self._sampler.iteration >= self._sampler.number_of_iterations

    def create(self, thermodynamic_state, sampler_states, topography, protocol,
               storage, restraint=None, anisotropic_dispersion_cutoff=None,
               alchemical_regions=None, alchemical_factory=None, metadata=None):
        """Create a new AlchemicalPhase calculation for a specified protocol.

        If ``anisotropic_dispersion_cutoff`` is not ``None``, the end states
        of the phase will be reweighted. The fully interacting state accounts
        for:

            1. The truncation of nonbonded interactions.
            2. The reciprocal space which is not modeled in alchemical
               states if an Ewald method is used for long-range interactions.

        Parameters
        ----------
        thermodynamic_state : openmmtools.states.ThermodynamicState
            Thermodynamic state holding the reference system, temperature
            and pressure.
        sampler_states : openmmtools.states.SamplerState or list
            One or more sets of initial sampler states. If a list of SamplerStates,
            they will be assigned to replicas in a round-robin fashion.
        topography : Topography
            The object holding the topology and labelling the different
            components of the system. This is used to discriminate between
            ligand-receptor and solvation systems.
        protocol : dict
            The dictionary ``{parameter_name: list_of_parameter_values}`` defining
            the protocol. All the parameter values list must have the same
            number of elements.
        storage : str or initialized Reporter class
            If str: Path to the storage file. The default checkpointing options
            (see the :class:`yank.repex.Reporter` class) will be used in this case.
            If :class:`yank.repex.Reporter`: Uses files and checkpointing options
            of the reporter class passed in.
        restraint : ReceptorLigandRestraint, optional
            Restraint to add between protein and ligand. This must be specified
            for ligand-receptor systems in non-periodic boxes.
        anisotropic_dispersion_cutoff : simtk.openmm.Quantity, 'auto', or None, optional, default None
            If specified, this is the cutoff at which to reweight long range
            interactions of the end states to correct for anisotropic dispersions.

            If `'auto'`, then the distance is automatically chosen based on the
            minimum possible size it can be given the box volume, then behaves
            as if a Quantity was passed in.

            If `None`, the correction won't be applied (units of length,
            default is None).
        alchemical_regions : openmmtools.alchemy.AlchemicalRegion, optional
            If specified, this is the ``AlchemicalRegion`` that will be passed
            to the ``AbsoluteAlchemicalFactory``, otherwise the ligand will be
            alchemically modified according to the given protocol.
        alchemical_factory : openmmtools.alchemy.AbsoluteAlchemicalFactory, optional
            If specified, this ``AbsoluteAlchemicalFactory`` will be used instead
            of the one created with default options.
        metadata : dict, optional
            Simulation metadata to be stored in the file. The dictionary passed
            in is not modified; a copy is extended and stored.

        Raises
        ------
        ValueError
            If the protocol parameters have different lengths, if the end
            states have different temperatures or pressures, if no atoms can
            be alchemically modified, or if a non-periodic receptor-ligand
            system has no restraint.
        RuntimeError
            If the system uses ``CutoffPeriodic``, or if a restraint is given
            but no receptor-ligand complex is found.
        AttributeError
            If a protocol parameter is not an attribute of the compound state.

        """
        # Check that protocol has same number of states for each parameter.
        len_protocol_parameters = {par_name: len(path) for par_name, path in protocol.items()}
        if len(set(len_protocol_parameters.values())) != 1:
            raise ValueError('The protocol parameters have a different number '
                             'of states: {}'.format(len_protocol_parameters))

        # Do not modify passed thermodynamic state.
        reference_thermodynamic_state = copy.deepcopy(thermodynamic_state)
        thermodynamic_state = copy.deepcopy(thermodynamic_state)
        reference_system = thermodynamic_state.system
        is_periodic = thermodynamic_state.is_periodic
        is_complex = len(topography.receptor_atoms) > 0

        # We currently don't support reaction field.
        nonbonded_method = mmtools.forces.find_nonbonded_force(reference_system).getNonbondedMethod()
        if nonbonded_method == openmm.NonbondedForce.CutoffPeriodic:
            raise RuntimeError('CutoffPeriodic is not supported yet. Use PME for explicit solvent.')

        # Make sure sampler_states is a list of SamplerStates.
        if isinstance(sampler_states, mmtools.states.SamplerState):
            sampler_states = [sampler_states]

        # Initialize metadata storage and handle default argument.
        # We'll use the sampler full name for resuming, the reference
        # thermodynamic state for minimization and the topography for
        # ligand randomization. Work on a shallow copy so that the
        # caller's dictionary is left untouched.
        metadata = {} if metadata is None else dict(metadata)
        sampler_full_name = mmtools.utils.typename(self._sampler.__class__)
        metadata['sampler_full_name'] = sampler_full_name
        metadata['reference_state'] = mmtools.utils.serialize(thermodynamic_state)
        metadata['topography'] = mmtools.utils.serialize(topography)

        # Add default title if user hasn't specified.
        if 'title' not in metadata:
            default_title = ('Alchemical free energy calculation created '
                             'using yank.AlchemicalPhase and {} on {}')
            metadata['title'] = default_title.format(sampler_full_name,
                                                     time.asctime(time.localtime()))

        # Restraint and standard state correction.
        # ----------------------------------------

        # Add receptor-ligand restraint and compute standard state corrections.
        restraint_state = None
        metadata['standard_state_correction'] = 0.0
        if is_complex and restraint is not None:
            logger.debug("Creating receptor-ligand restraints...")

            try:
                restraint.restrain_state(thermodynamic_state)
            except RestraintParameterError:
                logger.debug('There are undefined restraint parameters. '
                             'Trying automatic parametrization.')
                restraint.determine_missing_parameters(thermodynamic_state,
                                                       sampler_states[0], topography)
                restraint.restrain_state(thermodynamic_state)
            correction = restraint.get_standard_state_correction(thermodynamic_state)  # in kT
            metadata['standard_state_correction'] = correction

            # Create restraint state that will be part of composable states.
            restraint_state = RestraintState(lambda_restraints=1.0)

        # Raise error if we can't find a ligand-receptor to apply the restraint.
        elif restraint is not None:
            raise RuntimeError("Cannot apply the restraint. No receptor-ligand "
                               "complex could be found. ")

        # For not-restrained ligand-receptor periodic systems, we must still
        # add a standard state correction for the box volume.
        elif is_complex and is_periodic:
            # TODO: What if the box volume fluctuates during the simulation?
            box_vectors = reference_system.getDefaultPeriodicBoxVectors()
            box_volume = mmtools.states._box_vectors_volume(box_vectors)
            metadata['standard_state_correction'] = - np.log(V0 / box_volume)

        # For implicit solvent/vacuum complex systems, we require a restraint
        # to keep the ligand from drifting too far away from receptor.
        elif is_complex and not is_periodic:
            raise ValueError('A receptor-ligand system in implicit solvent or '
                             'vacuum requires a restraint.')

        # Create alchemical states.
        # -------------------------

        # Handle default alchemical region.
        if alchemical_regions is None:
            alchemical_regions = self._build_default_alchemical_region(
                reference_system, topography, protocol)

        # Check that we have atoms to alchemically modify.
        if len(alchemical_regions.alchemical_atoms) == 0:
            raise ValueError("Couldn't find atoms to alchemically modify.")

        # Create alchemically-modified system using alchemical factory.
        logger.debug("Creating alchemically-modified states...")
        if alchemical_factory is None:
            alchemical_factory = mmtools.alchemy.AbsoluteAlchemicalFactory(
                disable_alchemical_dispersion_correction=True)
        alchemical_system = alchemical_factory.create_alchemical_system(thermodynamic_state.system,
                                                                        alchemical_regions)

        # Create compound alchemically modified state to pass to sampler.
        thermodynamic_state.system = alchemical_system
        alchemical_state = mmtools.alchemy.AlchemicalState.from_system(alchemical_system)
        if restraint_state is not None:
            composable_states = [alchemical_state, restraint_state]
        else:
            composable_states = [alchemical_state]
        compound_state = mmtools.states.CompoundThermodynamicState(
            thermodynamic_state=thermodynamic_state, composable_states=composable_states)

        # Create all compound states to pass to sampler.create()
        # following the requested protocol. Each state is an independent
        # deep copy of the template with the protocol values applied.
        compound_states = []
        protocol_keys, protocol_values = zip(*protocol.items())
        for state_id, state_values in enumerate(zip(*protocol_values)):
            compound_states.append(copy.deepcopy(compound_state))
            for lambda_key, lambda_value in zip(protocol_keys, state_values):
                if hasattr(compound_state, lambda_key):
                    setattr(compound_states[state_id], lambda_key, lambda_value)
                else:
                    raise AttributeError('CompoundThermodynamicState object does not '
                                         'have protocol attribute {}'.format(lambda_key))

        # Temperature and pressure at the end states should
        # be the same or the analysis won't make sense.
        for state_property in ['temperature', 'pressure']:
            if getattr(compound_states[0], state_property) != getattr(compound_states[-1], state_property):
                raise ValueError('The {}s of the end states must be the same.'.format(state_property))

        # Expanded cutoff unsampled states.
        # ---------------------------------

        # TODO should we allow expanded states for non-periodic systems?
        logger.debug('Creating expanded cutoff states...')
        expanded_cutoff_states = []
        if is_periodic and anisotropic_dispersion_cutoff is not None:
            # Create non-alchemically modified state with an expanded cutoff.
            reference_state_expanded = self._expand_state_cutoff(reference_thermodynamic_state,
                                                                 anisotropic_dispersion_cutoff)
            # Add the restraint if any. The free energy of removing the restraint
            # will be taken into account with the standard state correction.
            if restraint is not None:
                restraint.restrain_state(reference_state_expanded)
                # The value of lambda_restraints must be the same as the first state.
                # TODO: handle case with multiple restraints.
                restraint_state.lambda_restraints = compound_states[0].lambda_restraints
                reference_state_expanded = mmtools.states.CompoundThermodynamicState(
                    thermodynamic_state=reference_state_expanded,
                    composable_states=[restraint_state])

            # Create the expanded cutoff decoupled state.
            last_state_expanded = self._expand_state_cutoff(compound_states[-1],
                                                            anisotropic_dispersion_cutoff)
            expanded_cutoff_states = [reference_state_expanded, last_state_expanded]
        elif anisotropic_dispersion_cutoff is not None:
            logger.warning("The requested anisotropic dispersion correction "
                           "won't be computed since the system is non-periodic.")

        # Create simulation.
        # ------------------

        logger.debug("Creating sampler object...")
        self._sampler.create(compound_states, sampler_states, storage=storage,
                             unsampled_thermodynamic_states=expanded_cutoff_states,
                             metadata=metadata)

    def minimize(self, tolerance=1.0*unit.kilojoules_per_mole/unit.nanometers,
                 max_iterations=0):
        """Minimize all the states.

        The minimization is performed in two steps. In the first one, the
        positions are minimized in the reference thermodynamic state (i.e.
        non alchemically-modified). Only then, the positions are minimized
        in their alchemically softened state.

        Parameters
        ----------
        tolerance : simtk.unit.Quantity, optional
            Minimization tolerance (units of energy/mole/length, default is
            ``1.0 * unit.kilojoules_per_mole / unit.nanometers``).
        max_iterations : int, optional
            Maximum number of iterations for minimization. If 0, minimization
            continues until converged.

        """
        metadata = self._sampler.metadata
        serialized_reference_state = metadata['reference_state']
        reference_state = mmtools.utils.deserialize(serialized_reference_state)
        sampler_states = self._sampler.sampler_states

        # We minimize only the sampler states that are in different positions.
        # This depends on how many sampler states have been passed in create()
        # and if the ligand has been randomized before calling minimize().
        similar_sampler_states = self._find_similar_sampler_states(sampler_states)
        logger.debug('Minimizing {} sampler states in the reference '
                     'thermodynamic state'.format(len(similar_sampler_states)))

        # Distribute minimization across nodes.
        minimized_sampler_states_ids = list(similar_sampler_states.keys())
        minimized_positions = mpi.distribute(self._minimize_sampler_state, minimized_sampler_states_ids,
                                             sampler_states, reference_state, tolerance, max_iterations,
                                             send_results_to='all')

        # Update all sampler states. States that shared the same starting
        # positions receive the positions of their representative.
        for sampler_state_id, minimized_pos in zip(minimized_sampler_states_ids, minimized_positions):
            sampler_states[sampler_state_id].positions = minimized_pos
            for similar_sampler_state_id in similar_sampler_states[sampler_state_id]:
                sampler_states[similar_sampler_state_id].positions = minimized_pos

        # Update sampler and perform second minimization in alchemically modified states.
        self._sampler.sampler_states = sampler_states
        self._sampler.minimize(tolerance=tolerance, max_iterations=max_iterations)

    def randomize_ligand(self, sigma_multiplier=2.0, close_cutoff=1.5*unit.angstrom):
        """Randomize the ligand positions in every state.

        The position and orientation of the ligand in each state will be randomized.
        This works only if the system is a ligand-receptor system.

        If you call this before minimizing, each positions will be minimized separately
        in the reference state, so you may want to call it afterwards to speed up
        minimization.

        Parameters
        ----------
        sigma_multiplier : float, optional
            The ligand will be placed close to a random receptor atom at a distance
            that is normally distributed with standard deviation
            ``sigma_multiplier * receptor_radius_of_gyration`` (default is 2.0).
        close_cutoff : simtk.unit.Quantity, optional
            Each random placement proposal will be rejected if the ligand ends up
            being closer to the receptor than this cutoff (units of length, default
            is ``1.5*unit.angstrom``).

        Raises
        ------
        RuntimeError
            If there are no ligand atoms or the system has solvent atoms.

        """
        metadata = self._sampler.metadata
        serialized_topography = metadata['topography']
        topography = mmtools.utils.deserialize(serialized_topography)

        # We can randomize the ligand only in implicit solvent.
        is_complex = len(topography.ligand_atoms) > 0
        is_explicit = len(topography.solvent_atoms) > 0
        if not is_complex:
            raise RuntimeError('Cannot find ligand atoms to randomize.')
        if is_explicit:
            raise RuntimeError('Cannot randomize ligand in explict solvent.')

        # Randomize all sampler states.
        sampler_states = self._sampler.sampler_states
        ligand_positions = mpi.distribute(self._randomize_ligand, sampler_states,
                                          topography, sigma_multiplier, close_cutoff,
                                          send_results_to='all')

        # Update sampler states with randomized positions.
        for sampler_state, ligand_pos in zip(sampler_states, ligand_positions):
            sampler_state.positions[topography.ligand_atoms] = ligand_pos
        self._sampler.sampler_states = sampler_states

    def equilibrate(self, n_iterations, mcmc_moves=None):
        """Equilibrate all states.

        Parameters
        ----------
        n_iterations : int
            Number of equilibration iterations.
        mcmc_moves : MCMCMove or list of MCMCMove, optional
            Optionally, the MCMCMoves to use for equilibration can be
            different from the ones used in production.

        """
        self._sampler.equilibrate(n_iterations=n_iterations, mcmc_moves=mcmc_moves)

    def run(self, n_iterations=None):
        """Run the alchemical phase simulation.

        Parameters
        ----------
        n_iterations : int, optional
           If specified, only at most the specified number of iterations
           will be run (default is None).

        """
        self._sampler.run(n_iterations=n_iterations)

    def extend(self, n_iterations):
        """Extend the simulation by the given number of iterations.

        Parameters
        ----------
        n_iterations : int
           The number of iterations to run.

        """
        self._sampler.extend(n_iterations)

    # -------------------------------------------------------------------------
    # Internal-usage
    # -------------------------------------------------------------------------

    @staticmethod
    def _expand_state_cutoff(thermodynamic_state, expanded_cutoff_distance,
                             replace_reaction_field=False, switch_width=None):
        """Expand the thermodynamic state cutoff to the given one.

        The passed state is not modified; a modified deep copy is returned.
        Cutoffs that are already at least ``expanded_cutoff_distance`` are
        left unchanged.

        If replace_reaction_field is True, the system will be modified to use
        an UnshiftedReactionFieldForce. In this case switch_width must be
        specified.

        Raises
        ------
        RuntimeError
            If the state is barostated and the requested cutoff is too large
            for the box to fluctuate safely.

        """
        # If we use a barostat we leave more room for volume fluctuations or
        # we risk fatal errors. This is how much we allow the box size to change.
        fluctuation_size = 0.8

        # Do not modify passed thermodynamic state.
        thermodynamic_state = copy.deepcopy(thermodynamic_state)
        system = thermodynamic_state.system

        # Determine minimum box side dimension. The theoretical maximal allowed cutoff
        # is given by half the norm of the smallest vector, but OpenMM limits it to
        # the minimum diagonal element of the box vector matrix for efficiency.
        box_vectors = system.getDefaultPeriodicBoxVectors()
        min_box_dimension = min([vector[i] for i, vector in enumerate(box_vectors)])

        # Determine cutoff automatically if requested.
        # We leave more space if the volume fluctuates.
        if expanded_cutoff_distance == 'auto':
            if thermodynamic_state.pressure is None:
                expanded_cutoff_distance = min_box_dimension * 0.99 / 2.0
            else:
                expanded_cutoff_distance = min_box_dimension * fluctuation_size / 2.0
            expanded_cutoff_distance = min(expanded_cutoff_distance, 16*unit.angstroms)
        # Otherwise check that requested cutoff is within fluctuation limits. If the
        # state is in NVT and the cutoff is too big, OpenMM will raise an exception
        # on Context creation.
        elif (thermodynamic_state.pressure is not None and
              min_box_dimension * fluctuation_size < 2.0 * expanded_cutoff_distance):
            raise RuntimeError('Barostated box sides must be at least {} Angstroms '
                               'to correct for missing dispersion interactions. The '
                               'minimum dimension of the provided box is {} Angstroms'
                               ''.format(expanded_cutoff_distance/unit.angstrom * 2,
                                         min_box_dimension/unit.angstrom))

        logger.debug('Setting cutoff for fully interacting system to {}. The minimum box '
                     'dimension is {}.'.format(expanded_cutoff_distance, min_box_dimension))

        # Expanded forces cutoff.
        for force in system.getForces():
            try:
                force_cutoff = force.getCutoffDistance()
            except AttributeError:
                # This force has no cutoff to expand.
                continue

            # We don't want to reduce the cutoff if it's already large.
            if force_cutoff >= expanded_cutoff_distance:
                continue

            cutoff_diff = expanded_cutoff_distance - force_cutoff
            force.setCutoffDistance(expanded_cutoff_distance)

            # Expand the switching distance by the same amount to preserve the
            # original switch width. Some forces (e.g. CustomGBForce) have a
            # cutoff but no switching function, so only touch it when present.
            # We don't need to check if we are using a switch since there is a
            # separate setting for that.
            if hasattr(force, 'getSwitchingDistance'):
                switching_distance = force.getSwitchingDistance()
                force.setSwitchingDistance(switching_distance + cutoff_diff)

        # Replace reaction field NonbondedForce to remove constant shift term.
        # AbsoluteAlchemicalFactory already does it for the other states.
        if replace_reaction_field:
            mmtools.forcefactories.replace_reaction_field(system, return_copy=False,
                                                          switch_width=switch_width)

        # Return the new thermodynamic state with the expanded cutoff.
        thermodynamic_state.system = system
        return thermodynamic_state

    @staticmethod
    def _build_default_alchemical_region(system, topography, protocol):
        """Create a default AlchemicalRegion if the user hasn't provided one.

        The ligand atoms are used if the topography has any, otherwise the
        solute atoms. In periodic systems the counterions returned by
        ``pipeline.find_alchemical_counterions`` are added. Bonds, angles and
        torsions are flagged for modification only if the protocol has a
        matching ``lambda_*`` parameter.
        """
        # TODO: we should probably have a second region that annihilate sterics of counterions.
        alchemical_region_kwargs = {}

        # Modify ligand if this is a receptor-ligand phase, or
        # solute if this is a transfer free energy calculation.
        if len(topography.ligand_atoms) > 0:
            alchemical_region_name = 'ligand_atoms'
        else:
            alchemical_region_name = 'solute_atoms'
        # Copy into a fresh list so that extending it below can never mutate
        # a list owned by the Topography (or misbehave if it is an array).
        alchemical_atoms = list(getattr(topography, alchemical_region_name))

        # In periodic systems, we alchemically modify the ligand/solute
        # counterions to make sure that the solvation box is always neutral.
        if system.usesPeriodicBoundaryConditions():
            alchemical_counterions = pipeline.find_alchemical_counterions(
                system, topography, alchemical_region_name)
            alchemical_atoms.extend(alchemical_counterions)

            # Sort them by index for safety. We don't want to
            # accidentally exchange two atoms' positions.
            alchemical_atoms = sorted(alchemical_atoms)

        alchemical_region_kwargs['alchemical_atoms'] = alchemical_atoms

        # Check if we need to modify bonds/angles/torsions.
        for element_type in ['bonds', 'angles', 'torsions']:
            modify_it = True if 'lambda_' + element_type in protocol else None
            alchemical_region_kwargs['alchemical_' + element_type] = modify_it

        # Create alchemical region.
        alchemical_region = mmtools.alchemy.AlchemicalRegion(**alchemical_region_kwargs)
        return alchemical_region

    @staticmethod
    def _find_similar_sampler_states(sampler_states):
        """Groups SamplerStates that have the same positions.

        Returns
        -------
        similar_sampler_states : dict
            The dict sampler_state_index: list_of_sampler_state_indices with same positions.

        """
        # similar_sampler_states is an ordered dict
        #       sampler_state_index: list of sampler_state_indices with same positions
        # we run only 1 minimization for each of these entries.
        similar_sampler_states = collections.OrderedDict()

        # processed_sampler_states_ids is a set containing all the
        # sampler state indices that have been assigned a minimization.
        processed_sampler_states_ids = set()

        # Find minimum number of minimizations required.
        for state_id, sampler_state in enumerate(sampler_states):
            if state_id in processed_sampler_states_ids:
                continue

            similar_sampler_states[state_id] = []
            processed_sampler_states_ids.add(state_id)

            for next_state_id in range(state_id+1, len(sampler_states)):
                next_sampler_state = sampler_states[next_state_id]
                if np.allclose(sampler_state.positions, next_sampler_state.positions):
                    similar_sampler_states[state_id].append(next_state_id)
                    processed_sampler_states_ids.add(next_state_id)

        return similar_sampler_states

    @staticmethod
    def _minimize_sampler_state(sampler_state_id, sampler_states, thermodynamic_state,
                                tolerance, max_iterations):
        """Minimize the specified sampler state at the given thermodynamic state.

        Returns the minimized positions of ``sampler_states[sampler_state_id]``.
        """
        sampler_state = sampler_states[sampler_state_id]

        # Retrieve a context. Any Integrator works.
        context, integrator = mmtools.cache.global_context_cache.get_context(thermodynamic_state)

        # Set initial positions and box vectors.
        sampler_state.apply_to_context(context)

        # Compute the initial energy of the system for logging.
        initial_energy = thermodynamic_state.reduced_potential(context)
        logger.debug('Sampler state {}/{}: initial energy {:8.3f}kT'.format(
            sampler_state_id + 1, len(sampler_states), initial_energy))

        # Minimize energy.
        openmm.LocalEnergyMinimizer.minimize(context, tolerance, max_iterations)

        # Get the minimized positions.
        sampler_state.update_from_context(context)

        # Compute the final energy of the system for logging.
        final_energy = thermodynamic_state.reduced_potential(sampler_state)
        logger.debug('Sampler state {}/{}: final energy {:8.3f}kT'.format(
            sampler_state_id + 1, len(sampler_states), final_energy))

        # Return minimized positions.
        return sampler_state.positions

    @staticmethod
    def _randomize_ligand(sampler_state, topography, sigma_multiplier, close_cutoff):
        """Randomize ligand positions of the given sampler state.

        Up to ``max_n_attempts`` placements are proposed. Each one centres the
        ligand on a random receptor atom, then applies a random rotation and
        displacement. The first placement whose minimum ligand-receptor distance
        is at least ``close_cutoff`` is accepted.

        Returns
        -------
        ligand_positions : simtk.unit.Quantity
            The randomized positions of the ligand atoms only.

        Raises
        ------
        RuntimeError
            If no acceptable placement is found within ``max_n_attempts``.

        """
        # Shortcut variables.
        ligand_atoms = topography.ligand_atoms
        receptor_atoms = topography.receptor_atoms

        # We set the standard deviation of the displacement
        # proportional to the receptor radius of gyration.
        radius_of_gyration = pipeline.compute_radius_of_gyration(sampler_state.positions[receptor_atoms])
        sigma = sigma_multiplier * radius_of_gyration

        # Convert to dimensionless positions.
        positions_unit = sampler_state.positions.unit
        x = sampler_state.positions / positions_unit
        close_cutoff = close_cutoff / positions_unit

        # We work with Quantity only for ligand atoms for readability.
        ligand_positions = x[ligand_atoms] * positions_unit

        # Try until we have a non-overlapping ligand conformation. Exactly
        # max_n_attempts proposals are made, matching the error message.
        max_n_attempts = 5000
        for _ in range(max_n_attempts):
            # Center ligand on a random receptor atom.
            ligand_positions_mean = ligand_positions.mean(0)
            receptor_atom_index = receptor_atoms[np.random.randint(0, len(receptor_atoms))]
            ligand_positions[:] += sampler_state.positions[receptor_atom_index] - ligand_positions_mean

            # Randomize ligand orientation and displace.
            ligand_positions = mmtools.mcmc.MCRotationMove.rotate_positions(ligand_positions)
            ligand_positions = mmtools.mcmc.MCDisplacementMove.displace_positions(ligand_positions, sigma)

            # Update array to compute distances.
            x[ligand_atoms, :] = ligand_positions / positions_unit

            # Accept the placement if there's no overlap.
            min_dist = pipeline.compute_min_dist(x[ligand_atoms], x[receptor_atoms])
            if min_dist >= close_cutoff:
                break
        else:
            # We could not find a working configuration.
            raise RuntimeError('Could not randomize ligand after {} attempts'.format(max_n_attempts))

        # We return only the randomized ligand positions to minimize MPI traffic.
        return ligand_positions