#!/usr/bin/env python
# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================
"""
MultiStateAnalyzers
===================
Analysis tools and module for MultiStateSampler simulations. Provides programmatic and automatic
"best practices" integration to determine free energy and other observables.
Fully extensible to support new samplers and observables.
"""
# =============================================================================================
# MODULE IMPORTS
# =============================================================================================
import abc
import copy
import inspect
import logging
import re
from typing import Optional, NamedTuple, Union
import mdtraj
import numpy as np
from simtk import openmm
import simtk.unit as units
from scipy.misc import logsumexp
from pymbar import MBAR, timeseries
import openmmtools as mmtools
from . import utils
# Short alias for the abstract base class; PhaseAnalyzer below derives from it.
ABC = abc.ABC
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Public names exported by ``from <module> import *``.
__all__ = [
    'PhaseAnalyzer',
    'MultiStateSamplerAnalyzer',
    'MultiPhaseAnalyzer',
    'ObservablesRegistry',
    'default_observables_registry'
]
# =============================================================================================
# GLOBAL VARIABLES
# =============================================================================================
# Molar Boltzmann constant (k_B * N_A): multiplying by a temperature gives an energy per mole.
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
# Unit constants for energies/distances; presumably the units in which OpenMM energies and
# MDTraj distances are expressed — they are not referenced in the portion shown here, TODO confirm.
_OPENMM_ENERGY_UNIT = units.kilojoules_per_mole
_MDTRAJ_DISTANCE_UNIT = units.nanometers
# =============================================================================================
# UTILITY FUNCTIONS
# =============================================================================================
def compute_centroid_distance(positions_group1, positions_group2, weights_group1, weights_group2):
    """Compute the distance between the weighted centroids of two groups of particles.

    When the weights are the particle masses, this is the distance between the
    centers of mass of the two groups. The two positions given must have the same units.

    Parameters
    ----------
    positions_group1 : numpy.array
        The positions of the particles in the first CustomCentroidBondForce group.
    positions_group2 : numpy.array
        The positions of the particles in the second CustomCentroidBondForce group.
    weights_group1 : list of float
        The weights (e.g. masses) of the particles in the first CustomCentroidBondForce group.
    weights_group2 : list of float
        The weights (e.g. masses) of the particles in the second CustomCentroidBondForce group.

    Returns
    -------
    distance : float
        Euclidean distance between the two weighted centroids, in the units of the positions.

    Raises
    ------
    ValueError
        If the number of positions in a group differs from its number of weights.
    """
    # Explicit checks rather than ``assert`` so the validation is not stripped under ``python -O``.
    groups = ((positions_group1, weights_group1), (positions_group2, weights_group2))
    for group_index, (positions, weights) in enumerate(groups, start=1):
        if len(positions) != len(weights):
            raise ValueError('Group {} has {} positions but {} weights.'.format(
                group_index, len(positions), len(weights)))
    # Weighted average over particles (axis 0) gives each group's centroid.
    com_group1 = np.average(positions_group1, axis=0, weights=weights_group1)
    com_group2 = np.average(positions_group2, axis=0, weights=weights_group2)
    # Compute distance between the centroids.
    distance = np.linalg.norm(com_group1 - com_group2)
    return distance
# =============================================================================================
# MODULE CLASSES
# =============================================================================================
class ObservablesRegistry(object):
    """
    Registry of computable observables.

    This is a class accessed by the :class:`PhaseAnalyzer` objects to check
    which observables can be computed, and then provide a regular categorization of them.

    This registry is a required linked component of any PhaseAnalyzer and especially of the MultiPhaseAnalyzer.
    This is not an internal class to the PhaseAnalyzer however because it can be instanced, extended, and customized
    as part of the API for this module.

    To define your own methods:

    1) Choose a unique observable name.
    2) Categorize the observable by registering it with the matching ``register_*_observable`` method:

       2a) "defined_by_phase" (``register_phase_observable``):
           Depends on the Phase as a whole (state independent)
       2b) "defined_by_single_state" (``register_one_state_observable``):
           Computed entirely from one state, e.g. Radius of Gyration
       2c) "defined_by_two_states" (``register_two_state_observable``):
           Property is relative to some reference state, such as Free Energy Difference

    3) Optionally choose how the error combines through the ``error_class`` argument of the registration method.
       If not placed in an error category, the observable will be assumed not to carry error.
       Examples: A, B, C are the observable in 3 phases, eA, eB, eC are the error of the observable in each phase

       3a) "linear": Error between phases adds linearly.
           If C = A + B, eC = eA + eB
       3b) "quad": Error between phases adds in the square.
           If C = A + B, eC = sqrt(eA^2 + eB^2)

    4) Finally, to add this observable to the phase, implement a "get_{method name}" method in the subclass of
       :class:`PhaseAnalyzer`. Any :class:`MultiPhaseAnalyzer` composed of this phase will automatically have the
       "get_{method name}" if all other phases in the :class:`MultiPhaseAnalyzer` have the same method.
    """

    def __init__(self):
        """Create an empty registry with no observables registered."""
        # Calculation categories; an observable belongs to exactly one of them.
        self._observables = {'two_state': set(),
                             'one_state': set(),
                             'phase': set()}
        # Error-propagation categories; the ``None`` key collects observables without error.
        self._errors = {'quad': set(),
                        'linear': set(),
                        None: set()}

    def register_two_state_observable(self, name: str,
                                      error_class: Optional[str] = None,
                                      re_register: bool = False):
        """
        Register a new two state observable, or re-register an existing one.

        Parameters
        ----------
        name: str
            Name of the observable, will be cast to all lower case and spaces replaced with underscores
        error_class: "quad", "linear", or None
            How the error of the observable is computed when added with other errors from the same observable.

            * "quad": Adds in the quadrature, Observable C = A + B, Error eC = sqrt(eA**2 + eB**2)
            * "linear": Adds linearly, Observable C = A + B, Error eC = eA + eB
            * None: Does not carry error

        re_register: bool, optional, Default: False
            Re-register an existing observable

        Raises
        ------
        ValueError
            If ``name`` is already registered and ``re_register`` is False, or ``error_class`` is unknown.
        """
        self._register_observable(name, "two_state", error_class, re_register=re_register)

    def register_one_state_observable(self, name: str,
                                      error_class: Optional[str] = None,
                                      re_register: bool = False):
        """
        Register a new one state observable, or re-register an existing one.

        Parameters
        ----------
        name: str
            Name of the observable, will be cast to all lower case and spaces replaced with underscores
        error_class: "quad", "linear", or None
            How the error of the observable is computed when added with other errors from the same observable.

            * "quad": Adds in the quadrature, Observable C = A + B, Error eC = sqrt(eA**2 + eB**2)
            * "linear": Adds linearly, Observable C = A + B, Error eC = eA + eB
            * None: Does not carry error

        re_register: bool, optional, Default: False
            Re-register an existing observable

        Raises
        ------
        ValueError
            If ``name`` is already registered and ``re_register`` is False, or ``error_class`` is unknown.
        """
        self._register_observable(name, "one_state", error_class, re_register=re_register)

    def register_phase_observable(self, name: str,
                                  error_class: Optional[str] = None,
                                  re_register: bool = False):
        """
        Register a new observable defined by phase, or re-register an existing one.

        Parameters
        ----------
        name: str
            Name of the observable, will be cast to all lower case and spaces replaced with underscores
        error_class: 'quad', 'linear', or None
            How the error of the observable is computed when added with other errors from the same observable.

            * 'quad': Adds in the quadrature, Observable C = A + B, Error eC = sqrt(eA**2 + eB**2)
            * 'linear': Adds linearly, Observable C = A + B, Error eC = eA + eB
            * None: Does not carry error

        re_register: bool, optional, Default: False
            Re-register an existing observable

        Raises
        ------
        ValueError
            If ``name`` is already registered and ``re_register`` is False, or ``error_class`` is unknown.
        """
        self._register_observable(name, "phase", error_class, re_register=re_register)

    ########################
    # Define the observables
    ########################
    @property
    def observables(self):
        """
        Tuple of all registered observables (union of the calculation categories below).
        """
        observables = set()
        for subset in self._observables.values():
            observables |= subset
        return tuple(observables)

    # ------------------------------------------------
    # Exclusive Observable categories
    # The intersection of these should be the null set
    # ------------------------------------------------
    @property
    def observables_defined_by_two_states(self):
        """
        Observables that require an i and a j state to define the observable accurately between phases
        """
        return self._get_observables('two_state')

    @property
    def observables_defined_by_single_state(self):
        """
        Defined observables which are fully defined by a single state, and not by multiple states such as differences
        """
        return self._get_observables('one_state')

    @property
    def observables_defined_by_phase(self):
        """
        Observables which are defined by the phase as a whole, and not defined by any 1 or more states
        e.g. Standard State Correction
        """
        return self._get_observables('phase')

    ##########################################
    # Define the observables which carry error
    # This should be a subset of observables
    ##########################################
    @property
    def observables_with_error(self):
        """Determine which observables have error by inspecting the error subsets"""
        observables = set()
        for subset_key, subset in self._errors.items():
            # The ``None`` category holds observables that carry no error.
            if subset_key is not None:
                observables |= subset
        return tuple(observables)

    # ------------------------------------------------
    # Exclusive Error categories
    # The intersection of these should be the null set
    # ------------------------------------------------
    @property
    def observables_with_error_adding_quadrature(self):
        """Observable C = A + B, Error eC = sqrt(eA**2 + eB**2)"""
        return self._get_errors('quad')

    @property
    def observables_with_error_adding_linear(self):
        """Observable C = A + B, Error eC = eA + eB"""
        return self._get_errors('linear')

    @property
    def observables_without_error(self):
        """Observables registered with ``error_class=None``."""
        return self._get_errors(None)

    # ------------------
    # Internal functions
    # ------------------
    def _get_observables(self, key):
        return tuple(self._observables[key])

    def _get_errors(self, key):
        return tuple(self._errors[key])

    @staticmethod
    def _cast_observable_name(name) -> str:
        """Lower-case ``name`` and replace each run of spaces with a single underscore."""
        return re.sub(" +", "_", name.lower())

    def _register_observable(self, obs_name: str,
                             obs_calc_class: str,
                             obs_error_class: Union[None, str],
                             re_register: bool = False):
        obs_name = self._cast_observable_name(obs_name)
        if not re_register and obs_name in self.observables:
            raise ValueError("{} is already a registered observable! "
                             "Consider setting re_register key!".format(obs_name))
        # Validate both categories before mutating anything so a failure leaves the registry unchanged.
        self._check_obs_class(obs_calc_class)
        self._check_obs_error_class(obs_error_class)
        obs_name_set = {obs_name}  # set(single_object) throws an error, set(string) splits each char
        # Throw out existing observable if present (set difference)
        for obs_key in self._observables:
            self._observables[obs_key] -= obs_name_set
        for obs_err_key in self._errors:
            self._errors[obs_err_key] -= obs_name_set
        # Add new observable to correct classifiers (set union)
        self._observables[obs_calc_class] |= obs_name_set
        self._errors[obs_error_class] |= obs_name_set

    def _check_obs_class(self, obs_class):
        # Explicit raise (not ``assert``) so the check is not stripped under ``python -O``.
        if obs_class not in self._observables:
            raise ValueError("{} not a known observable class!".format(obs_class))

    def _check_obs_error_class(self, obs_error):
        # ``None`` is itself a key of self._errors, so one membership test covers it.
        if obs_error not in self._errors:
            raise ValueError("{} not a known observable error class!".format(obs_error))
# Create a default registry and register some stock values.
# This single module-level instance is the default ``registry`` argument of PhaseAnalyzer.__init__,
# so it is shared by every analyzer that does not pass its own registry; registrations added to it
# later are visible to all of them.
default_observables_registry = ObservablesRegistry()
# Two-state observables (differences between reference states i and j) whose errors add in quadrature.
default_observables_registry.register_two_state_observable('free_energy', error_class='quad')
default_observables_registry.register_two_state_observable('entropy', error_class='quad')
default_observables_registry.register_two_state_observable('enthalpy', error_class='quad')
# -----------------------------------------------------------------------------
# EXCEPTIONS.
# -----------------------------------------------------------------------------
class InsufficientData(Exception):
    """Raised when the data are not sufficient to perform the requested analysis."""
# -----------------------------------------------------------------------------
# CACHED PROPERTIES DESCRIPTOR.
# -----------------------------------------------------------------------------
class CachedProperty(object):
    """Descriptor exposing one entry of an analyzer's cache, with dependency tracking.

    The value is stored in ``instance._cache[name]``. Every write goes through
    the instance's ``_update_cache`` hook (which, in PhaseAnalyzer, invalidates
    the cache entries that list this property among their dependencies).

    Parameters
    ----------
    name : str
        Key of this value in the instance cache.
    dependencies : iterable of str
        Names of the cached properties this property depends on.
    check_changes : bool, optional
        When True, dependent entries are invalidated only if the new value
        differs from the old one (default is False).
    default : object, optional
        Value used on a cache miss. A callable is invoked as
        ``default(self, instance)``; the ``default`` decorator is an
        alternative way to provide it. Whatever is obtained on the first
        miss is stored in the cache. Without a default, a cache miss
        raises AttributeError.
    validator : callable, optional
        Invoked as ``validator(self, instance, new_value)`` before a new
        value is stored; its return value is what gets cached. It can also
        be attached with the ``validator`` decorator.
    """

    def __init__(self, name, dependencies=(), check_changes=False,
                 default=AttributeError, validator=None):
        # These names are reserved: observables have their own storage and the
        # reporter is a plain read-only property on the analyzer.
        # TODO make observables CachedProperties?
        assert name not in ('observables', 'reporter')
        # TODO use __set_name__() when dropping Python 3.5 support.
        self._check_changes = check_changes
        self._validator = validator
        self._default = default
        self.dependencies = dependencies
        self.name = name

    def __get__(self, instance, owner_class=None):
        # Accessed on the class itself: hand back the descriptor.
        if instance is None:
            return self
        cache = instance._cache
        if self.name in cache:
            return cache[self.name]
        # Cache miss: compute the fallback and remember it for next time.
        fallback_value = self._get_default(instance)
        instance._update_cache(self.name, fallback_value, self._check_changes)
        return fallback_value

    def __set__(self, instance, new_value):
        validate = self._validator
        if validate is not None:
            new_value = validate(self, instance, new_value)
        instance._update_cache(self.name, new_value, self._check_changes)

    def _clone(self, default, validator):
        """Return a copy of this descriptor with the given default and validator."""
        return type(self)(self.name, self.dependencies, self._check_changes, default, validator)

    def validator(self, validator):
        return self._clone(self._default, validator)

    def default(self, default):
        return self._clone(default, self._validator)

    def _get_default(self, instance):
        fallback = self._default
        if fallback is AttributeError:
            raise AttributeError('Reference before assignment {}.{}'.format(instance, self.name))
        return fallback(self, instance) if callable(fallback) else fallback
# ---------------------------------------------------------------------------------------------
# Phase Analyzers
# ---------------------------------------------------------------------------------------------
[docs]class PhaseAnalyzer(ABC):
"""
Analyzer for a single phase of a MultiState simulation.
Uses the reporter from the simulation to determine the location
of all variables.
To compute a specific observable in an implementation of this class, add it to the ObservableRegistry and then
implement a ``get_X`` where ``X`` is the name of the observable you want to compute. See the ObservablesRegistry for
information about formatting the observables.
Analyzer works in units of kT unless specifically stated otherwise. To convert back to a unit set, just multiply by
the .kT property.
A PhaseAnalyzer also needs an ObservablesRegistry to track how to handle each observable given implemented within
for things like error and cross-phase analysis.
Parameters
----------
reporter : MultiStateReporter instance
Reporter from MultiState which ties to the simulation data on disk.
name : str, Optional
Unique name you want to assign this phase, this is the name that will appear in :class:`MultiPhaseAnalyzer`'s.
If not set, it will be given the arbitrary name "phase#" where # is an integer, chosen in order that it is
assigned to the :class:`MultiPhaseAnalyzer`.
max_n_iterations : int, optional
The maximum number of iterations to analyze. If not provided, all
the iterations will be analyzed.
reference_states : tuple of ints, length 2, Optional, Default: (0,-1)
Integers ``i`` and ``j`` of the state that is used for reference in observables, "O". These values are only used
when reporting single numbers or combining observables through :class:`MultiPhaseAnalyzer` (since the number of
states between phases can be different). Calls to functions such as ``get_free_energy`` in a single Phase
results in the O being returned for all states.
For O completely defined by the state itself (i.e. no differences between states, e.g. Temperature),
only O[i] is used
For O where differences between states are required (e.g. Free Energy): O[i,j] = O[j] - O[i]
For O defined by the phase as a whole, the reference states are not needed.
analysis_kwargs : None or dict, optional
Dictionary of extra keyword arguments to pass into the analysis tool, typically MBAR.
For instance, the initial guess of relative free energies to give to MBAR would be something like:
``{'initial_f_k':[0,1,2,3]}``
registry : ObservablesRegistry instance
Instanced ObservablesRegistry with all observables implemented through a ``get_X`` function classified and
registered. Any cross-phase analysis must use the same instance of an ObservablesRegistry
Attributes
----------
name
observables
max_n_iterations
reference_states
n_iterations
n_replicas
n_states
kT
reporter
registry
See Also
--------
ObservablesRegistry
"""
def __init__(self, reporter, name=None, reference_states=(0, -1),
max_n_iterations=None, analysis_kwargs=None,
registry=default_observables_registry):
"""
The reporter provides the hook into how to read the data, all other options control where differences are
measured from and how each phase interfaces with other phases.
"""
# Arguments validation.
if type(reporter) is str:
raise ValueError('reporter must be a MultiStateReporter instance')
if not isinstance(registry, ObservablesRegistry):
raise ValueError("Registry must be an instanced ObservablesRegistry")
if analysis_kwargs is None:
analysis_kwargs = {}
elif not isinstance(analysis_kwargs, dict):
raise ValueError('analysis_kwargs must be either None or a dictionary')
self.registry = registry
if not reporter.is_open():
reporter.open(mode='r')
self._reporter = reporter
# Initialize cached observables so the phase can be retrieved once computed.
self._computed_observables = {observable: None for observable in self.observables}
# Internal properties
self._name = name
# Start as default sign +, handle all sign conversion at preparation time
self._sign = '+'
self._reference_states = None # Initialize the cache object.
self.reference_states = reference_states
self._extra_analysis_kwargs = analysis_kwargs
# Initialize cached values that are read or derived from the Reporter.
self._cache = {} # This cache should be always set with _update_cache().
self.clear()
self.max_n_iterations = max_n_iterations
[docs] def clear(self):
"""Reset all cached objects.
This must to be called if the information in the reporter changes
after analysis.
"""
# Reset cached values that are read directly from the Reporter.
self._n_iterations = None
self._n_replicas = None
self._end_thermodynamic_states = None
self._kT = None
# Reset cached values that are derived from the reporter.
self._invalidate_cache_values('reporter')
@property
def name(self):
"""User-readable string name of the phase"""
return self._name
@name.setter
def name(self, value):
self._name = value
@property
def observables(self):
"""List of observables that the instanced analyzer can compute/fetch."""
# Auto-determine the computable observables by inspection of non-flagged methods
# We determine valid observables by negation instead of just having each child
# implement the method to enforce uniform function naming conventions.
observables = []
for observable in self.registry.observables:
if hasattr(self, "get_" + observable):
observables.append(observable)
# Cast observables to an immutable.
return tuple(observables)
@property
def reference_states(self):
"""Tuple of reference states ``i`` and ``j`` for :class:`MultiPhaseAnalyzer` instances"""
return self._reference_states
@reference_states.setter
def reference_states(self, value):
"""Provide a way to re-assign the ``i, j`` states in a protected way"""
i, j = value[0], value[1]
if type(i) is not int or type(j) is not int:
raise ValueError("reference_states must be a length 2 iterable of ints")
self._reference_states = (i, j)
@property
def n_iterations(self):
"""int: The total number of iterations of the phase."""
if self._n_iterations is None:
# The + 1 accounts for iteration 0.
self._n_iterations = self._reporter.read_last_iteration(full_iteration=False)
return self._n_iterations
@property
def n_replicas(self):
"""int: Number of replicas."""
if self._n_replicas is None:
replica_state_indices = self._reporter.read_replica_thermodynamic_states(iteration=0)
self._n_replicas = len(replica_state_indices)
return self._n_replicas
@property
def n_states(self):
"""int: Number of sampled thermodynamic states."""
return self._reporter.n_states
def _get_end_thermodynamic_states(self):
"""Read thermodynamic states at the ends of the protocol."""
# TODO: what if reference_states changes? This should become a CachedProperty.
# Check cached values.
if self._end_thermodynamic_states is None:
sampled_states, unsampled_states = self._reporter.read_thermodynamic_states()
if len(unsampled_states) == 0:
self._end_thermodynamic_states = [sampled_states[0], sampled_states[-1]]
else:
self._end_thermodynamic_states = unsampled_states
# Cache other useful informations since we have already read this.
# TODO should we read temperatures of all the states and let kT property depend on reference_states?
self._kT = self._end_thermodynamic_states[0].kT
return self._end_thermodynamic_states
@property
def kT(self):
"""
Quantity of boltzmann constant times temperature of the phase in units of energy per mol
Allows conversion between dimensionless energy and unit bearing energy
"""
if self._kT is None:
self._get_end_thermodynamic_states()
return self._kT
@property
def reporter(self):
"""Sampler Reporter tied to this object."""
return self._reporter
@reporter.setter
def reporter(self, value):
"""Make sure users cannot overwrite the reporter."""
raise ValueError("You cannot re-assign the reporter for this analyzer!")
# -------------------------------------------------------------------------
# Cached properties functions/classes.
# -------------------------------------------------------------------------
@classmethod
def _get_cache_dependency_graph(cls):
"""dict: cached_value -> list of cache values to invalidate."""
# Retrieve all cached properties.
cached_properties = {value for name, value in inspect.getmembers(cls)
if isinstance(value, CachedProperty)}
# Build the dependency graph.
dependency_graph = {}
for cached_property in cached_properties:
for dependency in cached_property.dependencies:
try:
dependency_graph[dependency].add(cached_property.name)
except KeyError:
dependency_graph[dependency] = {cached_property.name}
# Hard-code observable dependency since those are not CachedProperties.
# TODO make observables CachedProperties?
dependency_graph['mbar'] = {'observables'}
return dependency_graph
def _update_cache(self, key, new_value, check_changes=False):
"""Update the cache entry and invalidate the values that depend on it.
Parameters
----------
key : str
The name of the value to update.
new_value : object
The new value of the key.
check_changes : bool, optional
If True and the new value is equal to the current one,
the dependent cache values are not invalidated.
"""
invalidate_cache = True
try:
old_value = self._cache[key]
except KeyError:
invalidate_cache = False
else:
if check_changes and old_value == new_value:
invalidate_cache = False
# Update value and invalidate the cache.
self._cache[key] = new_value
if invalidate_cache:
self._invalidate_cache_values(key)
def _invalidate_cache_values(self, key):
"""Invalidate all the cache dependencies of key.
Parameters
----------
key : str
The name of the cached whose dependencies must be invalidated.
"""
dependency_graph = self._get_cache_dependency_graph()
for k in dependency_graph[key]:
# Invalidate observables that are in a separate cache.
if k == 'observables':
for observable in self.observables:
self._computed_observables[observable] = None
else:
# Invalidate dependencies of k.
self._invalidate_cache_values(k)
# Remove k.
self._cache.pop(k, None)
# -------------------------------------------------------------------------
# Cached properties.
# -------------------------------------------------------------------------
max_n_iterations = CachedProperty('max_n_iterations', check_changes=True)
@max_n_iterations.validator
def max_n_iterations(self, instance, new_value):
"""The maximum allowed value for max_n_iterations is n_iterations."""
if new_value is None or new_value > instance.n_iterations:
new_value = instance.n_iterations
return new_value
# -------------------------------------------------------------------------
# Abstract methods.
# -------------------------------------------------------------------------
[docs] def read_energies(self):
"""
Extract energies from the ncfile and order them by replica, state, iteration.
Returns
-------
sampled_energy_matrix : np.ndarray of shape [n_replicas, n_states, n_iterations]
Potential energy matrix of the sampled states.
unsampled_energy_matrix : np.ndarray of shape [n_replicas, n_unsamped_states, n_iterations]
Potential energy matrix of the unsampled states.
Energy from each drawn sample n, evaluated at unsampled state l.
If no unsampled states were drawn, this will be shape (0,N).
neighborhoods : np.ndarray of shape [n_replicas, n_states, n_iterations]
Neighborhood energies were computed at, uses a boolean mask over the energy_matrix.
replica_state_indices : np.ndarray of shape [n_replicas, n_iterations]
States sampled by the replicas in the energy_matrix
"""
# TODO: should we keep it unified and always truncate to max_n_iterations?
return self._read_energies(truncate_max_n_iterations=False)
def _read_energies(self, truncate_max_n_iterations):
"""
Extract energies from the ncfile and order them by replica, state, iteration.
Optionally truncate the data to self.max_n_iterations.
"""
logger.info("Reading energies...")
# reporter_energies is [energy_sampled_states, neighborhoods, energy_unsampled_states].
energy_data = list(self._reporter.read_energies())
energy_data.append(self._reporter.read_replica_thermodynamic_states())
logger.info("Done.")
# Truncate the number of iterations to self.max_n_iterations if requested.
if truncate_max_n_iterations:
for i, energies in enumerate(energy_data):
# The +1 accounts for minimization iteration.
energy_data[i] = energies[:self.max_n_iterations+1]
# Convert from (n_iterations, n_replicas, n_states) to (n_replicas, n_states, n_iterations).
for i, energies in enumerate(energy_data):
energy_data[i] = np.moveaxis(energies, 0, -1)
# Unpack.
sampled_energy_matrix, neighborhoods, unsampled_energy_matrix, replicas_state_indices = energy_data
# TODO: Figure out what format we need the data in to be useful for both global and local MBAR/WHAM
# For now, we simply can't handle analysis of non-global calculations.
if np.any(neighborhoods == 0):
raise Exception('Non-global MBAR analysis not implemented yet.')
return sampled_energy_matrix, unsampled_energy_matrix, neighborhoods, replicas_state_indices
@property
def has_log_weights(self):
"""
Return True if the storage has log weights, False otherwise
"""
try:
# Check that logZ and log_weights have per-iteration data
# If either of these return a ValueError, then no history data are available
_ = self._reporter.read_logZ(0)
_ = self._reporter.read_online_analysis_data(0, 'log_weights')
return True
except ValueError:
return False
[docs] def read_log_weights(self):
"""
Extract log weights from the ncfile, if present.
Returns ValueError if not present.
Returns
-------
log_weights : np.ndarray of shape [n_states, n_iterations]
log_weights[l,n] is the log weight applied to state ``l``
during the collection of samples at iteration ``n``
"""
log_weights = np.array(
self._reporter.read_online_analysis_data(slice(None, None), 'log_weights')['log_weights'])
log_weights = np.moveaxis(log_weights, 0, -1)
return log_weights
[docs] def read_logZ(self, iteration=None):
"""
Extract logZ estimates from the ncfile, if present.
Returns ValueError if not present.
Parameters
----------
iteration : int or slice, optional, default=None
If specified, iteration or slice of iterations to extract
Returns
-------
logZ : np.ndarray of shape [n_states, n_iterations]
logZ[l,n] is the online logZ estimate for state ``l`` at iteration ``n``
"""
if iteration == -1:
log_z = self._reporter.read_logZ(iteration)
else:
if iteration is not None:
log_z = self._reporter.read_online_analysis_data(iteration, "logZ")["logZ"]
else:
log_z = self._reporter.read_online_analysis_data(slice(0, None), "logZ")["logZ"]
log_z = np.moveaxis(log_z, 0, -1)
return log_z
[docs] def get_effective_energy_timeseries(self, energies=None, replica_state_indices=None):
"""
Generate the effective energy (negative log deviance) timeseries that is generated for this phase
The effective energy for a series of samples x_n, n = 1..N, is defined as
u_n = - \ln \pi(x_n) + c
where \pi(x) is the probability density being sampled, and c is an arbitrary constant.
Parameters
----------
energies : ndarray of shape (K,L,N), optional, Default: None
Energies from replicas K, sampled states L, and iterations N
If provided, then states input_sampled_states must also be provided
replica_state_indices : ndarray of shape (K,N), optional, Default: None
Integer indices of each sampled state (matching L dimension in input_energy)
that each replica K sampled every iteration N.
If provided, then states input_energies must also be provided
Returns
-------
u_n : ndarray of shape (N,)
u_n[n] is the negative log deviance of the same from iteration ``n``
Timeseries used to determine equilibration time and statistical inefficiency.
"""
raise NotImplementedError("This class has not implemented this function")
# -------------------------------------------------------------------------
# MBAR routines.
# -------------------------------------------------------------------------
@staticmethod
# Private Class Methods
def _create_mbar(self, energy_matrix, samples_per_state):
"""
Initialize MBAR for Free Energy and Enthalpy estimates, this may take a while.
This function is helpful for those who want to create a slightly different mbar object with different
parameters.
This function is hidden from the user unless they really, really need to create their own mbar object
Parameters
----------
energy_matrix : array of numpy.float64, optional, default=None
Reduced potential energies of the replicas; if None, will be extracted from the ncfile
samples_per_state : array of ints, optional, default=None
Number of samples drawn from each kth state; if None, will be extracted from the ncfile
"""
# Initialize MBAR (computing free energy estimates, which may take a while)
logger.info("Computing free energy differences...")
self.mbar = MBAR(energy_matrix, samples_per_state, **self._extra_analysis_kwargs)
return self.mbar
# -------------------------------------------------------------------------
# Analysis combination.
# -------------------------------------------------------------------------
def _combine_phases(self, other, operator='+'):
    """
    Workhorse function when creating a :class:`MultiPhaseAnalyzer` object by combining single
    :class:`PhaseAnalyzer`s.

    Parameters
    ----------
    other : PhaseAnalyzer or MultiPhaseAnalyzer
        The phase(s) to combine with this one.
    operator : str, optional, default='+'
        ``'+'`` adds ``other`` to this phase; any other value subtracts it.

    Returns
    -------
    multiphase_analyzer : MultiPhaseAnalyzer
        Analyzer over this phase followed by the phase(s) in ``other``.

    Raises
    ------
    TypeError
        If ``other`` is neither a :class:`PhaseAnalyzer` nor a :class:`MultiPhaseAnalyzer`.

    Notes
    -----
    A pending sign set with unary minus (``-phase``) is folded into the
    returned signs and reset to ``'+'`` on the operands. When adding, each
    operand of ``other`` keeps its own sign; when subtracting, it is flipped,
    so ``A + (-B)`` is equivalent to ``A - B``.
    """
    # Validate first so that a failed operation leaves the operands' signs untouched.
    if not isinstance(other, (MultiPhaseAnalyzer, PhaseAnalyzer)):
        action = 'add' if operator == '+' else 'subtract'
        raise TypeError("cannot {} 'PhaseAnalyzer' and '{}' objects".format(action, type(other).__name__))

    def combined_sign(sign):
        # Adding preserves the operand's sign; subtracting flips it.
        if operator == '+':
            return sign
        return '-' if sign == '+' else '+'

    phases = [self]
    signs = [self._sign]
    # Reset self._sign now that it has been recorded.
    self._sign = '+'
    if self.name is None:
        names = [utils.generate_phase_name(self.name, [])]
    else:
        names = [self.name]

    if isinstance(other, MultiPhaseAnalyzer):
        new_names = other.names
        # Disambiguate each incoming name against its siblings and this phase's name.
        final_new_names = []
        for name in new_names:
            other_names = [n for n in new_names if n != name]
            final_new_names.append(utils.generate_phase_name(name, other_names + names))
        names.extend(final_new_names)
        signs.extend(combined_sign(new_sign) for new_sign in other.signs)
        phases.extend(other.phases)
    else:
        names.append(utils.generate_phase_name(other.name, names))
        signs.append(combined_sign(other._sign))
        # Reset the other's sign if it got set to negative.
        other._sign = '+'
        phases.append(other)

    phase_pass = {'phases': phases, 'signs': signs, 'names': names}
    return MultiPhaseAnalyzer(phase_pass)
def __add__(self, other):
    """Combine this phase with ``other`` into a MultiPhaseAnalyzer (``self + other``)."""
    combined = self._combine_phases(other, operator='+')
    return combined
def __sub__(self, other):
    """Combine this phase with ``other`` into a MultiPhaseAnalyzer (``self - other``)."""
    combined = self._combine_phases(other, operator='-')
    return combined
def __neg__(self):
    """Toggle this phase's internal sign flag in place and return the same object.

    The flag is consumed (and reset) the next time the phase is combined
    with another one via ``+`` or ``-``.
    """
    self._sign = '-' if self._sign == '+' else '+'
    return self
[docs]class MultiStateSamplerAnalyzer(PhaseAnalyzer):
"""
The MultiStateSamplerAnalyzer is the analyzer for a simulation generated from a MultiStateSampler simulation,
implemented as an instance of the :class:`PhaseAnalyzer`.
Parameters
----------
unbias_restraint : bool, optional
If True and a radially-symmetric restraint was used in the simulation,
the analyzer will remove the bias introduced by the restraint by
reweighting each of the end-points to a state using a square-well
potential restraint.
restraint_energy_cutoff : float or 'auto', optional
When the restraint is unbiased, the analyzer discards all the samples
for which the restraint potential energy (in kT) is above this cutoff.
Effectively, this is equivalent to placing a hard wall potential at a
restraint distance such that the restraint potential energy is equal to
``restraint_energy_cutoff``.
If ``'auto'`` and ``restraint_distance_cutoff`` is ``None``, this will
be set to the 99.9-percentile of the distribution of the restraint energies
in the bound state.
restraint_distance_cutoff : simtk.unit.Quantity or 'auto', optional
When the restraint is unbiased, the analyzer discards all the samples
for which the distance between the restrained atoms is above this cutoff.
Effectively, this is equivalent to placing a hard wall potential at a
restraint distance ``restraint_distance_cutoff``.
If ``'auto'`` and ``restraint_energy_cutoff`` is not specified, this will
be set to the 99.9-percentile of the distribution of the restraint distances
in the bound state.
Attributes
----------
unbias_restraint
restraint_energy_cutoff
restraint_distance_cutoff
mbar
n_equilibration_iterations
statistical_inefficiency
See Also
--------
PhaseAnalyzer
"""
def __init__(self, *args, unbias_restraint=True, restraint_energy_cutoff='auto',
             restraint_distance_cutoff='auto', **kwargs):
    """Create the analyzer.

    Positional arguments and any keyword arguments other than the three
    restraint options are forwarded unchanged to the parent constructor.
    """
    # Per the original note, the parent constructor calls clear(), which
    # initializes the cached variables; it therefore has to run first.
    super().__init__(*args, **kwargs)

    # Restraint-unbiasing options (cached values with dependencies), set in
    # this exact order after the parent initialization.
    self.unbias_restraint = unbias_restraint
    self.restraint_energy_cutoff = restraint_energy_cutoff
    self.restraint_distance_cutoff = restraint_distance_cutoff
# TODO use class syntax and add docstring after dropping python 3.5 support.
# Container returned by generate_mixing_statistics():
#   transition_matrix: (n_states, n_states) symmetrized state transition matrix.
#   eigenvalues: eigenvalues of transition_matrix, sorted in descending order.
#   statistical_inefficiency: scalar statistical inefficiency of the replica
#       state index timeseries (pymbar's statisticalInefficiencyMultiple).
_MixingStatistics = NamedTuple('MixingStatistics', [
    ('transition_matrix', np.ndarray),
    ('eigenvalues', np.ndarray),
    ('statistical_inefficiency', float)
])
def clear(self):
    """Invalidate every cached quantity held by this analyzer.

    Call this whenever the data stored by the reporter has changed since
    the last analysis.
    """
    # The parent class invalidates its own cached properties.
    super().clear()
    # Restraint-related caches owned by this subclass.
    self._radially_symmetric_restraint_data = None
    self._restraint_energies = dict()
    self._restraint_distances = dict()
def generate_mixing_statistics(self, number_equilibrated: Union[int, None] = None) -> NamedTuple:
    """
    Compute and return replica mixing statistics.

    Compute the symmetrized transition matrix between thermodynamic states,
    its eigenvalues sorted from greatest to least, and the statistical
    inefficiency of the replica state index timeseries.

    Parameters
    ----------
    number_equilibrated : int, optional, default=None
        If specified, only samples ``number_equilibrated:end`` will
        be used in analysis. If not specified, ``self.n_equilibration_iterations``
        is used.

    Returns
    -------
    mixing_statistics : namedtuple
        A namedtuple containing the following attributes:
        - ``transition_matrix``: (nstates by nstates ``np.array``)
        - ``eigenvalues``: (nstates-dimensional real ``np.array``, descending order)
        - ``statistical_inefficiency``: float
    """
    if number_equilibrated is None:
        number_equilibrated = self.n_equilibration_iterations
    # Read data from disk: states[n, k] is the state index of replica k at iteration n.
    states = self._reporter.read_replica_thermodynamic_states()
    n_iterations, _ = states.shape
    n_states = self._reporter.n_states

    # Compute empirical transition count matrix over consecutive equilibrated iterations.
    # np.add.at accumulates correctly when the same (i, j) pair occurs more than once.
    from_states = states[number_equilibrated:n_iterations - 1]
    to_states = states[number_equilibrated + 1:n_iterations]
    n_ij = np.zeros([n_states, n_states], np.int64)
    np.add.at(n_ij, (from_states.ravel(), to_states.ravel()), 1)

    # Compute transition matrix estimate from the symmetrized counts.
    # TODO: Replace with maximum likelihood reversible count estimator from msmbuilder or pyemma.
    c_ij = n_ij + n_ij.T
    # Row sums of c_ij = transitions out of state i + transitions into state i.
    denominators = c_ij.sum(axis=1)
    t_ij = np.zeros([n_states, n_states], np.float64)
    visited = denominators > 0
    t_ij[visited] = c_ij[visited] / denominators[visited, np.newaxis]
    # States with no observed transitions get an identity row.
    unvisited = np.flatnonzero(~visited)
    t_ij[unvisited, unvisited] = 1.0

    # Estimate eigenvalues. t_ij is D^-1 C with C symmetric and D diagonal, so it is
    # similar to a symmetric matrix and its eigenvalues are real; eigvals() may still
    # return them as complex with negligible imaginary parts, which we discard.
    mu = np.real(np.linalg.eigvals(t_ij))
    mu = -np.sort(-mu)  # Sort in descending order

    # Compute state index statistical inefficiency of stationary data.
    # states[n][k] is the state index of replica k at iteration n, but
    # the functions wants a list of timeseries states[k][n].
    states_kn = np.transpose(states[number_equilibrated:])
    g = timeseries.statisticalInefficiencyMultiple(states_kn)

    return self._MixingStatistics(transition_matrix=t_ij, eigenvalues=mu,
                                  statistical_inefficiency=g)
def show_mixing_statistics(self, cutoff=0.05, number_equilibrated=None):
    """
    Print summary of mixing statistics. Passes information off to generate_mixing_statistics then prints it out to
    the logger.

    Parameters
    ----------
    cutoff : float, optional, default=0.05
        Only transition probabilities greater than or equal to ``cutoff`` will be printed.
    number_equilibrated : int, optional, default=None
        If specified, only samples number_equilibrated:end will be used in analysis.
        If not specified, it uses the internally held statistics best.
    """
    mixing_statistics = self.generate_mixing_statistics(number_equilibrated=number_equilibrated)
    transition_matrix = mixing_statistics.transition_matrix

    # Print observed transition probabilities.
    nstates = transition_matrix.shape[1]
    logger.info("Cumulative symmetrized state mixing transition matrix:")
    logger.info("{:6s}".format("") + "".join("{:6d}".format(jstate) for jstate in range(nstates)))
    for istate in range(nstates):
        cells = ["{:6.3f}".format(P) if P >= cutoff else "{:6s}".format("")
                 for P in transition_matrix[istate, :nstates]]
        logger.info("{:-6d}".format(istate) + "".join(cells))

    # Estimate second eigenvalue and equilibration time.
    # Use the real part: the symmetrized transition matrix has real eigenvalues,
    # but they may be stored with negligible imaginary parts.
    eigenvalues = np.real(mixing_statistics.eigenvalues)
    if len(eigenvalues) < 2:
        # A single state has no second eigenvalue (indexing [1] would fail).
        logger.info('Only one thermodynamic state; cannot estimate the state equilibration timescale.')
    else:
        perron_eigenvalue = eigenvalues[1]
        if perron_eigenvalue >= 1:
            logger.info('Perron eigenvalue is unity; Markov chain is decomposable.')
        else:
            equilibration_timescale = 1.0 / (1.0 - perron_eigenvalue)
            logger.info('Perron eigenvalue is {0:.5f}; state equilibration timescale '
                        'is ~ {1:.1f} iterations'.format(perron_eigenvalue, equilibration_timescale)
                        )

    # Print information about replica state index statistical efficiency.
    logger.info('Replica state index statistical inefficiency is '
                '{:.3f}'.format(mixing_statistics.statistical_inefficiency))
def _get_radially_symmetric_restraint_data(self):
    """Return the radially-symmetric restraint force used in the bound state.

    The result is cached on the instance after the first successful call.

    Returns
    -------
    restraint_force : openmmtools.forces.RadiallySymmetricRestraintForce
        A private copy of the restraint force used in the bound state.
    weights_group1 : list of simtk.unit.Quantity
        The masses of the restrained atoms in the first centroid group.
    weights_group2 : list of simtk.unit.Quantity
        The masses of the restrained atoms in the second centroid group.

    Raises
    ------
    TypeError
        If the end states don't have lambda_restraints set to 1.
    openmmtools.forces.NoForceFoundError
        If there are no radially-symmetric restraints in the bound state.
    """
    cached_data = self._radially_symmetric_restraint_data
    if cached_data is not None:
        return cached_data

    end_states = self._get_end_thermodynamic_states()
    first_state = end_states[0]
    last_state = end_states[-1]
    bound_system = first_state.system

    # find_forces raises mmtools.forces.NoForceFoundError when there is nothing to unbias.
    _, found_force = mmtools.forces.find_forces(
        bound_system,
        force_type=mmtools.forces.RadiallySymmetricRestraintForce,
        only_one=True,
        include_subclasses=True,
    )
    # The System owns the force: keep our own copy so it is not deallocated with it.
    restraint_force = copy.deepcopy(found_force)

    # The restraint can only be unbiased if it is fully on at both end states.
    restraint_fully_on = first_state.lambda_restraints == 1.0 and last_state.lambda_restraints == 1.0
    if not restraint_fully_on:
        raise TypeError('Cannot unbias a restraint that is turned off at one of the end states.')

    # Centroid weights are the particle masses of each restrained group.
    masses_group1 = [bound_system.getParticleMass(atom) for atom in restraint_force.restrained_atom_indices1]
    masses_group2 = [bound_system.getParticleMass(atom) for atom in restraint_force.restrained_atom_indices2]

    # Cache so that the system does not have to be deserialized again.
    self._radially_symmetric_restraint_data = restraint_force, masses_group1, masses_group2
    return self._radially_symmetric_restraint_data
# -------------------------------------------------------------------------
# MBAR creation.
# -------------------------------------------------------------------------
def get_effective_energy_timeseries(self, energies=None, replica_state_indices=None):
    r"""
    Generate the effective energy (negative log deviance) timeseries that is generated for this phase.

    The effective energy for a series of samples x_n, n = 1..N, is defined as

        u_n = - \ln \pi(x_n) + c

    where \pi(x) is the probability density being sampled, and c is an arbitrary constant.

    Parameters
    ----------
    energies : ndarray of shape (K,L,N), optional, Default: None
        Energies from replicas K, sampled states L, and iterations N.
        If provided, then ``replica_state_indices`` must also be provided.
    replica_state_indices : ndarray of shape (K,N), optional, Default: None
        Integer indices of each sampled state (matching L dimension in ``energies``)
        that each replica K sampled every iteration N.
        If provided, then ``energies`` must also be provided.

    Returns
    -------
    u_n : ndarray of shape (N,)
        u_n[n] is the negative log deviance of the sample from iteration ``n``.
        Timeseries used to determine equilibration time and statistical inefficiency.

    Raises
    ------
    ValueError
        If exactly one of ``energies`` and ``replica_state_indices`` is given.
    """
    if energies is None and replica_state_indices is None:
        # Case where no input is provided: read the (truncated) energies from the reporter.
        energies, _, _, replica_state_indices = self._read_energies(truncate_max_n_iterations=True)
    elif (energies is None) != (replica_state_indices is None):
        # XOR: only one of the two arrays was given.
        raise ValueError("If energies or replica_state_indices are provided, "
                         "then the other cannot be None due to ambiguity!")

    n_replicas, n_states, n_iterations = energies.shape
    # Only the first n_iterations columns pair up with the energies.
    state_indices_kn = np.asarray(replica_state_indices)[:, :n_iterations]
    replica_idx = np.arange(n_replicas)[:, np.newaxis]
    iteration_idx = np.arange(n_iterations)[np.newaxis, :]

    # sampled_u_kn[k, n] = energies[k, state sampled by replica k at iteration n, n]
    sampled_u_kn = energies[replica_idx, state_indices_kn, iteration_idx]
    u_n = np.asarray(sampled_u_kn.sum(axis=0), dtype=np.float64)

    # Correct for potentially-changing log weights.
    if self.has_log_weights:
        # log_weights is indexed [state, iteration], as in the original per-iteration code.
        log_weights = self.read_log_weights()
        f_l = - self.read_logZ(iteration=-1)  # use last (best) estimate of free energies
        log_weights_ln = log_weights[:, :n_iterations]
        sampled_log_weights_n = log_weights_ln[state_indices_kn, iteration_idx].sum(axis=0)
        log_normalization_n = logsumexp(-f_l[:, np.newaxis] + log_weights_ln, axis=0)
        u_n += -sampled_log_weights_n + n_replicas * log_normalization_n
    return u_n
def _compute_mbar_decorrelated_energies(self):
    """Return an MBAR-ready decorrelated energy matrix.

    The data is returned after discarding equilibration and truncating
    the iterations to self.max_n_iterations. The results are also cached in
    ``self._decorrelated_u_ln`` and ``self._decorrelated_N_l``.

    Returns
    -------
    energy_matrix : energy matrix of shape (K,N) indexed by k,n
        K is the total number of states observables are desired.
        N is the total number of samples drawn from ALL states.
        The nth configuration is the energy evaluated in the kth thermodynamic state.
    samples_per_state : 1-D iterable of shape K
        The number of samples drawn from each kth state.
        The sum of samples_per_state is N.
    """
    # energy_data is [energy_sampled, energy_unsampled, neighborhood, replicas_state_indices]
    energy_data = list(self._read_energies(truncate_max_n_iterations=True))

    # Use the cached information to generate the equilibration data.
    sampled_energy_matrix, unsampled_energy_matrix, neighborhood, replicas_state_indices = energy_data
    number_equilibrated, g_t, _ = self._get_equilibration_data(sampled_energy_matrix,
                                                               replicas_state_indices)

    # Apply the same equilibration cut and subsampling to every array along the
    # last (iteration) axis so they stay aligned with each other.
    for i, data in enumerate(energy_data):
        # Discard equilibration iterations.
        data = utils.remove_unequilibrated_data(data, number_equilibrated, -1)
        # Subsample along the decorrelation data.
        energy_data[i] = utils.subsample_data_along_axis(data, g_t, -1)
    sampled_energy_matrix, unsampled_energy_matrix, neighborhood, replicas_state_indices = energy_data

    # Initialize the MBAR matrices in ln form.
    n_replicas, n_sampled_states, n_iterations = sampled_energy_matrix.shape
    _, n_unsampled_states, _ = unsampled_energy_matrix.shape
    n_total_states = n_sampled_states + n_unsampled_states
    energy_matrix = np.zeros([n_total_states, n_iterations*n_replicas])
    samples_per_state = np.zeros([n_total_states], dtype=int)

    # Compute shift index for how many unsampled states there were.
    # This assumes an equal number of unsampled states at each end point.
    first_sampled_state = n_unsampled_states // 2
    last_sampled_state = n_total_states - first_sampled_state

    # Cast the sampled energy matrix from kln' to ln form.
    energy_matrix[first_sampled_state:last_sampled_state, :] = self.reformat_energies_for_mbar(sampled_energy_matrix)
    # Determine how many samples and which states they were drawn from.
    unique_sampled_states, counts = np.unique(replicas_state_indices, return_counts=True)
    # Assign those counts to the correct range of states (the slice is a view).
    samples_per_state[first_sampled_state:last_sampled_state][unique_sampled_states] = counts
    # Add energies of unsampled states to the end points.
    if n_unsampled_states > 0:
        # NOTE(review): writing rows [0, -1] assumes exactly one unsampled state per end point.
        energy_matrix[[0, -1], :] = self.reformat_energies_for_mbar(unsampled_energy_matrix)
        logger.info("Found expanded cutoff states in the energies!")
        logger.info("Free energies will be reported relative to them instead!")

    # These cached values speed up considerably the computation of the
    # free energy profile along the restraint distance/energy cutoff.
    self._decorrelated_u_ln = energy_matrix
    self._decorrelated_N_l = samples_per_state

    return self._decorrelated_u_ln, self._decorrelated_N_l
def _compute_mbar_unbiased_energies(self):
    """Unbias the restraint, and apply restraint energy/distance cutoffs.

    When there is a restraint to unbias, the function adds two extra unbiased
    states at the end points of the energy matrix. Otherwise, the return value
    is the cached output of self._compute_mbar_decorrelated_energies().

    This method reads ``self._decorrelated_u_ln`` and ``self._decorrelated_N_l``
    (cached by ``_compute_mbar_decorrelated_energies()``) and
    ``self._decorrelated_state_indices_ln`` (defined elsewhere in the class),
    so these must be available when it is called. The results are cached in
    ``self._unbiased_decorrelated_u_ln`` and ``self._unbiased_decorrelated_N_l``.

    Returns
    -------
    unbiased_decorrelated_u_ln : np.array
        Array of energies (in kT) with one column per retained decorrelated
        sample. Samples outside the restraint cutoffs are dropped, so there
        are generally fewer columns than in the decorrelated matrix when a
        restraint cutoff is set. When the restraint is unbiased, two extra
        rows (the unbiased end states) are added at the top and bottom.
    unbiased_decorrelated_N_l : np.array
        The total number of samples drawn from each state (including the
        unbiased states, which have zero samples).
    """
    # Check if we need to unbias the restraint.
    unbias_restraint = self.unbias_restraint
    if unbias_restraint:
        try:
            restraint_data = self._get_radially_symmetric_restraint_data()
        except (TypeError, mmtools.forces.NoForceFoundError) as e:
            # If we don't need to unbias the restraint there's nothing else to do.
            logger.info(str(e) + ' The restraint will not be unbiased.')
            unbias_restraint = False
    if not unbias_restraint:
        # Nothing to unbias: reuse the decorrelated matrices unchanged.
        self._unbiased_decorrelated_u_ln = self._decorrelated_u_ln
        self._unbiased_decorrelated_N_l = self._decorrelated_N_l
        return self._unbiased_decorrelated_u_ln, self._unbiased_decorrelated_N_l

    # Compute the restraint energies/distances.
    restraint_force, weights_group1, weights_group2 = restraint_data
    logger.debug('Found {} restraint. The restraint will be unbiased.'.format(restraint_force.__class__.__name__))
    logger.debug('Receptor restrained atoms: {}'.format(restraint_force.restrained_atom_indices1))
    logger.debug('ligand restrained atoms: {}'.format(restraint_force.restrained_atom_indices2))

    # Compute restraint energies/distances.
    energies_ln, distances_ln = self._compute_restraint_energies(restraint_force, weights_group1,
                                                                 weights_group2)

    # Convert energies to kT unit for comparison to energy cutoff.
    energies_ln = energies_ln / self.kT
    logger.debug('Restraint energy mean: {} kT; std: {} kT'
                 ''.format(np.mean(energies_ln), np.std(energies_ln, ddof=1)))

    # Don't modify the cached decorrelated energies.
    u_ln = copy.deepcopy(self._decorrelated_u_ln)
    N_l = copy.deepcopy(self._decorrelated_N_l)
    n_decorrelated_iterations_ln = u_ln.shape[1]
    # Every sample column must have a matching restraint energy and sampled-state index.
    assert len(energies_ln) == n_decorrelated_iterations_ln, '{}, {}'.format(energies_ln.shape, u_ln.shape)
    assert len(self._decorrelated_state_indices_ln) == n_decorrelated_iterations_ln

    # Determine the cutoffs to use for the simulations.
    restraint_energy_cutoff, restraint_distance_cutoff = self._get_restraint_cutoffs()
    apply_energy_cutoff = restraint_energy_cutoff is not None
    apply_distance_cutoff = restraint_distance_cutoff is not None

    # We need to take into account the initial unsampled states to index correctly N_l.
    n_unsampled_states = len(u_ln) - self.n_states
    first_sampled_state_idx = int(n_unsampled_states / 2)

    # Determine which samples are outside the cutoffs or have to be truncated.
    columns_to_keep = []
    for iteration_ln_idx, state_idx in enumerate(self._decorrelated_state_indices_ln):
        if ((apply_energy_cutoff and energies_ln[iteration_ln_idx] > restraint_energy_cutoff) or
                (apply_distance_cutoff and distances_ln[iteration_ln_idx] > restraint_distance_cutoff)):
            # Update the number of samples generated from its state.
            N_l[state_idx + first_sampled_state_idx] -= 1
        else:
            columns_to_keep.append(iteration_ln_idx)

    # Drop all columns that exceed the cutoff(s).
    n_discarded = n_decorrelated_iterations_ln - len(columns_to_keep)
    logger.debug('Discarding {}/{} samples outside the cutoffs (restraint_distance_cutoff: {}, '
                 'restraint_energy_cutoff: {}).'.format(n_discarded, n_decorrelated_iterations_ln,
                                                        restraint_distance_cutoff,
                                                        restraint_energy_cutoff))
    u_ln = u_ln[:, columns_to_keep]

    # Add new end states that don't include the restraint.
    energies_ln = energies_ln[columns_to_keep]
    n_states, n_iterations = u_ln.shape
    n_states_new = n_states + 2
    # The two new end states receive no samples (their N_l entries stay zero).
    N_l_new = np.zeros(n_states_new, N_l.dtype)
    u_ln_new = np.zeros((n_states_new, n_iterations), u_ln.dtype)
    # Unbiased end states = first/last existing states minus the restraint energy (in kT).
    u_ln_new[0, :] = u_ln[0] - energies_ln
    u_ln_new[-1, :] = u_ln[-1] - energies_ln

    # Copy old values.
    N_l_new[1:-1] = N_l
    u_ln_new[1:-1, :] = u_ln

    # Cache new values.
    self._unbiased_decorrelated_u_ln = u_ln_new
    self._unbiased_decorrelated_N_l = N_l_new
    return self._unbiased_decorrelated_u_ln, self._unbiased_decorrelated_N_l
def _compute_restraint_energies(self, restraint_force, weights_group1, weights_group2):
    """Compute the restraint energies and distances for the uncorrelated iterations.

    Values are computed once per (iteration, replica) and cached in
    ``self._restraint_energies`` and ``self._restraint_distances`` (dicts keyed
    by iteration, then by replica index), so repeated calls only evaluate the
    decorrelated iterations that are not cached yet.

    Parameters
    ----------
    restraint_force : openmmtools.forces.RadiallySymmetricRestraintForce
        The restraint force. It is deep-copied, so the argument is not modified.
    weights_group1 : list
        The masses of the particles in the first CustomCentroidBondForce group
        (used as particle masses of the reduced system and as centroid weights).
    weights_group2 : list
        The masses of the particles in the second CustomCentroidBondForce group.

    Returns
    -------
    restraint_energies_ln : simtk.unit.Quantity
        A (n_replicas * n_decorrelated_iterations)-long array with
        the restraint energies (units of energy/mole), ordered replica-major
        (all decorrelated iterations of replica 0 first, then replica 1, ...).
    restraint_distances_ln : simtk.unit.Quantity or None
        A (n_replicas * n_decorrelated_iterations)-long array with the
        restraint distances (units of length) for each frame, in the same
        order, or None if the distances of some decorrelated iteration are not
        cached. Distances are only computed when a distance cutoff is in use.
    """
    decorrelated_iterations = self._decorrelated_iterations  # Shortcut.
    decorrelated_iterations_set = set(decorrelated_iterations)

    # Determine total number of energies/distances to compute:
    # one per replica for each decorrelated iteration.
    n_frames_ln = self.n_replicas * len(decorrelated_iterations)

    # Computing the restraint energies/distances is expensive and we
    # don't want to recompute everything when _decorrelated_iterations
    # changes (e.g. when max_n_iterations changes) so we keep the cached
    # values of the iterations we have computed.
    # The dictionary instead of a masked array is for memory efficiency
    # since the matrix will be very sparse (especially with SAMS).
    def extract_decorrelated(cached_dict, dtype, unit):
        # Flatten cached_dict[iteration][replica] into a replica-major array,
        # or return None if any decorrelated iteration is missing from the cache.
        if not decorrelated_iterations_set.issubset(set(cached_dict)):
            return None
        decorrelated = np.zeros(n_frames_ln, dtype=dtype)
        for replica_idx in range(self.n_replicas):
            for iteration_idx, iteration in enumerate(decorrelated_iterations):
                frame_idx = replica_idx*len(decorrelated_iterations) + iteration_idx
                decorrelated[frame_idx] = cached_dict[iteration][replica_idx]
        return decorrelated * unit

    # We compute the distances only if we are using a distance cutoff.
    _, compute_distances = self._get_use_restraint_cutoff()

    # Check cached values.
    if compute_distances and decorrelated_iterations_set.issubset(set(self._restraint_distances)):
        compute_distances = False
    if decorrelated_iterations_set.issubset(set(self._restraint_energies)) and not compute_distances:
        return (extract_decorrelated(self._restraint_energies, dtype=np.float64, unit=_OPENMM_ENERGY_UNIT),
                extract_decorrelated(self._restraint_distances, dtype=np.float32, unit=_MDTRAJ_DISTANCE_UNIT))

    # Don't modify the original restraint force.
    restraint_force = copy.deepcopy(restraint_force)
    is_periodic = restraint_force.usesPeriodicBoundaryConditions()

    # Store the original indices of the restrained atoms.
    original_restrained_atom_indices1 = restraint_force.restrained_atom_indices1
    original_restrained_atom_indices2 = restraint_force.restrained_atom_indices2
    original_restrained_atom_indices = (original_restrained_atom_indices1 +
                                        original_restrained_atom_indices2)

    # Create new system with only solute and restraint forces.
    reduced_system = openmm.System()
    for weight in weights_group1 + weights_group2:
        reduced_system.addParticle(weight)
    # Adapt the restraint force atom indices to the reduced system.
    n_atoms1 = len(weights_group1)
    n_atoms = n_atoms1 + len(weights_group2)
    restraint_force.restrained_atom_indices1 = list(range(n_atoms1))
    restraint_force.restrained_atom_indices2 = list(range(n_atoms1, n_atoms))
    reduced_system.addForce(restraint_force)

    # If we need to image the molecule, we need an MDTraj trajectory.
    if compute_distances and is_periodic:
        # Create topology with only the restrained atoms.
        serialized_topography = self._reporter.read_dict('metadata/topography')
        topography = mmtools.utils.deserialize(serialized_topography)
        topology = topography.topology
        topology = topology.subset(self._reporter.analysis_particle_indices)
        # Use the receptor as an anchor molecule and image the ligand.
        anchor_molecules = [{a for a in topology.atoms if a.index in set(topography.receptor_atoms)}]
        imaged_molecules = [{a for a in topology.atoms if a.index in set(topography.ligand_atoms)}]
        # Initialize trajectory object needed for imaging molecules.
        trajectory = mdtraj.Trajectory(xyz=np.zeros((topology.n_atoms, 3)), topology=topology)

    # Create context used to compute the energies.
    integrator = openmm.VerletIntegrator(1.0*units.femtosecond)
    context = openmm.Context(reduced_system, integrator)

    # TODO: we need to provide a reporter generator to iterate over single
    # TODO: iterations but reading automatically one chunksize at a time.
    # chunk_size = self._reporter.checkpoint_interval
    # iterations_groups = itertools.groupby(enumerate(decorrelated_iterations), key=lambda x: int(x[1] / chunk_size))

    # Pre-computing energies/distances.
    logger.debug('Computing restraint energies/distances...')
    for iteration_idx, iteration in enumerate(decorrelated_iterations):
        # Check if we have already computed this energy/distance.
        if (iteration in self._restraint_energies and
                (not compute_distances or iteration in self._restraint_distances)):
            continue
        self._restraint_energies[iteration] = {}
        if compute_distances:
            self._restraint_distances[iteration] = {}

        # Read sampler states only if we haven't computed this iteration yet.
        # Obtain solute only sampler states.
        sampler_states = self._reporter.read_sampler_states(iteration=iteration,
                                                            analysis_particles_only=True)

        for replica_idx, sampler_state in enumerate(sampler_states):
            # Keep only the restrained atoms, matching the reduced system's particle order.
            sliced_sampler_state = sampler_state[original_restrained_atom_indices]
            sliced_sampler_state.apply_to_context(context)
            potential_energy = context.getState(getEnergy=True).getPotentialEnergy()
            # Cache unitless values (in _OPENMM_ENERGY_UNIT).
            self._restraint_energies[iteration][replica_idx] = potential_energy / _OPENMM_ENERGY_UNIT

            if compute_distances:
                # Check if an analytical solution is available.
                try:
                    distance = restraint_force.distance_at_energy(potential_energy) / _MDTRAJ_DISTANCE_UNIT
                except (NotImplementedError, ValueError):
                    # Fall back to computing the centroid distance from the positions.
                    if is_periodic:
                        # Update trajectory positions/box vectors.
                        trajectory.xyz = (sampler_state.positions / _MDTRAJ_DISTANCE_UNIT).astype(np.float32)
                        trajectory.unitcell_vectors = np.array([sampler_state.box_vectors / _MDTRAJ_DISTANCE_UNIT],
                                                               dtype=np.float32)
                        trajectory.image_molecules(inplace=True, anchor_molecules=anchor_molecules,
                                                   other_molecules=imaged_molecules)
                        positions_group1 = trajectory.xyz[0][original_restrained_atom_indices1]
                        positions_group2 = trajectory.xyz[0][original_restrained_atom_indices2]
                    else:
                        positions_group1 = sampler_state.positions[original_restrained_atom_indices1]
                        positions_group2 = sampler_state.positions[original_restrained_atom_indices2]
                        positions_group1 /= _MDTRAJ_DISTANCE_UNIT
                        positions_group2 /= _MDTRAJ_DISTANCE_UNIT

                    # Set output arrays.
                    distance = compute_centroid_distance(positions_group1, positions_group2,
                                                         weights_group1, weights_group2)
                self._restraint_distances[iteration][replica_idx] = distance

    return (extract_decorrelated(self._restraint_energies, dtype=np.float64, unit=_OPENMM_ENERGY_UNIT),
            extract_decorrelated(self._restraint_distances, dtype=np.float32, unit=_MDTRAJ_DISTANCE_UNIT))
def _get_use_restraint_cutoff(self):
    """Determine if we need to apply a cutoff on the restraint energies and/or distances.

    Returns
    -------
    apply_energy_cutoff : bool
        True if a numeric energy cutoff was given (any real number, e.g. an
        int, a float or a NumPy scalar), or if the energy cutoff is ``'auto'``
        and the distance cutoff is ``None``.
    apply_distance_cutoff : bool
        True if a ``simtk.unit.Quantity`` distance cutoff was given, or if the
        distance cutoff is ``'auto'`` and no numeric energy cutoff was given.
    """
    import numbers

    def is_auto(value):
        # Only a genuine 'auto' string counts; avoids comparing Quantities with strings.
        return isinstance(value, str) and value == 'auto'

    energy_cutoff = self.restraint_energy_cutoff
    distance_cutoff = self.restraint_distance_cutoff

    apply_distance_cutoff = isinstance(distance_cutoff, units.Quantity)
    # Accept any real number, not only Python floats (an int cutoff used to be
    # ignored here, so the automatic distance cutoff was applied on top of it).
    apply_energy_cutoff = isinstance(energy_cutoff, numbers.Real) and not isinstance(energy_cutoff, bool)

    # When both cutoffs are auto, use distance cutoff.
    if is_auto(distance_cutoff) and not apply_energy_cutoff:
        apply_distance_cutoff = True
    elif is_auto(energy_cutoff) and distance_cutoff is None:
        apply_energy_cutoff = True
    return apply_energy_cutoff, apply_distance_cutoff
def _get_restraint_energies_distances_at_state(self, state_idx, get_energies=True, get_distances=True):
    """Return the cached restraint energies and distances of all samples drawn from one state.

    Every cached iteration is scanned (not only the decorrelated ones).

    Parameters
    ----------
    state_idx : int
        Thermodynamic state index; negative values count from the last state.
    get_energies : bool, optional
        If False, the energies are not collected and None is returned for them.
    get_distances : bool, optional
        If False, the distances are not collected and None is returned for them.

    Returns
    -------
    state_energies : np.ndarray or None
        Restraint energies in units of kT.
    state_distances : simtk.unit.Quantity or None
        Restraint distances in ``_MDTRAJ_DISTANCE_UNIT``.
    """
    # Map negative indices onto the range of valid states.
    if state_idx < 0:
        state_idx += self.n_states
    replica_state_indices = self._reporter.read_replica_thermodynamic_states()

    def collect(cached_data):
        # Gather the cached values of all replicas that were in state_idx.
        collected = []
        for iteration, replica_values in cached_data.items():
            replicas_in_state = np.where(replica_state_indices[iteration] == state_idx)[0]
            collected.extend(replica_values[replica_idx] for replica_idx in replicas_in_state)
        return collected

    state_energies = None
    if get_energies:
        state_energies = np.array(collect(self._restraint_energies)) * _OPENMM_ENERGY_UNIT / self.kT

    state_distances = None
    if get_distances:
        state_distances = np.array(collect(self._restraint_distances)) * _MDTRAJ_DISTANCE_UNIT

    return state_energies, state_distances
def _determine_automatic_restraint_cutoff(self, compute_energy_cutoff=True, compute_distance_cutoff=True):
    """Automatically determine the restraint cutoffs.

    This must be called after _compute_restraint_energies(). Each cutoff is
    the 99.9th percentile of the restraint energies/distances observed in
    the bound state (thermodynamic state 0).

    Returns
    -------
    energy_cutoff : float or None
        Energy cutoff in kT, or None if ``compute_energy_cutoff`` is False.
    distance_cutoff : simtk.unit.Quantity or None
        Distance cutoff, or None if ``compute_distance_cutoff`` is False.

    Raises
    ------
    InsufficientData
        If state 0 has no cached samples for a requested cutoff.
    """
    bound_energies, bound_distances = self._get_restraint_energies_distances_at_state(
        state_idx=0, get_energies=compute_energy_cutoff, get_distances=compute_distance_cutoff)

    def percentile_cutoff(samples, quantity_name):
        # A percentile of an empty distribution is meaningless: refuse to guess.
        if len(samples) == 0:
            raise InsufficientData('Thermodynamic state 0 has not been sampled enough to '
                                   'determine automatically the restraint {} cutoff.'.format(quantity_name))
        return np.percentile(samples, 99.9)

    energy_cutoff = percentile_cutoff(bound_energies, 'energy') if compute_energy_cutoff else None

    distance_cutoff = None
    if compute_distance_cutoff:
        # Work on unitless values and reattach the unit to the result.
        distance_values = bound_distances / _MDTRAJ_DISTANCE_UNIT
        distance_cutoff = percentile_cutoff(distance_values, 'distance') * _MDTRAJ_DISTANCE_UNIT

    return energy_cutoff, distance_cutoff
def _get_restraint_cutoffs(self):
    """Return the restraint energies and distance cutoff to be used for unbiasing.

    An ``'auto'`` cutoff is resolved from the bound-state data when
    ``_get_use_restraint_cutoff()`` says it is in use, and becomes None
    otherwise; any other value is returned unchanged.

    Returns
    -------
    restraint_energy_cutoff : float or None
        Energy cutoff (kT) or None.
    restraint_distance_cutoff : simtk.unit.Quantity or None
        Distance cutoff or None.
    """
    use_energy_cutoff, use_distance_cutoff = self._get_use_restraint_cutoff()

    # Resolve the distance cutoff if necessary.
    distance_cutoff = self.restraint_distance_cutoff
    if distance_cutoff == 'auto':
        if use_distance_cutoff:
            _, distance_cutoff = self._determine_automatic_restraint_cutoff(compute_energy_cutoff=False)
            logger.debug('Chosen automatically a restraint distance cutoff of {}'.format(distance_cutoff))
        else:
            distance_cutoff = None

    # Resolve the energy cutoff if necessary.
    energy_cutoff = self.restraint_energy_cutoff
    if energy_cutoff == 'auto':
        if use_energy_cutoff:
            energy_cutoff, _ = self._determine_automatic_restraint_cutoff(compute_distance_cutoff=False)
            logger.debug('Chosen automatically a restraint energy cutoff of {}kT'.format(energy_cutoff))
        else:
            energy_cutoff = None

    return energy_cutoff, distance_cutoff
# -------------------------------------------------------------------------
# Observables.
# -------------------------------------------------------------------------
def _compute_free_energy(self):
    """
    Estimate free energies of all alchemical states.

    Logs the MBAR free-energy difference matrix and its uncertainty matrix at
    INFO level. Caches both in ``self._computed_observables['free_energy']``
    as ``{'value': Deltaf_ij, 'error': dDeltaf_ij}``.
    """
    nstates = self.mbar.N_k.size

    # Get matrix of dimensionless free energy differences and uncertainty estimate.
    logger.info("Computing covariance matrix...")
    # pymbar 2 returns (Deltaf_ij, dDeltaf_ij). pymbar 3 returns a 3-tuple whose
    # first two entries are the same. Call the estimator once and take the first
    # two entries. The old approach retried after a tuple-unpacking ValueError:
    # under pymbar 3 that re-ran the whole estimate, and a genuine ValueError from
    # inside pymbar also triggered a confusing second call.
    results = self.mbar.getFreeEnergyDifferences()
    Deltaf_ij, dDeltaf_ij = results[0], results[1]

    def log_matrix(title, matrix):
        """Log ``title`` followed by one fixed-width row per state."""
        logger.info(title)
        for i in range(nstates):
            logger.info("".join("{:8.3f}".format(matrix[i, j]) for j in range(nstates)))

    # Matrix of free energy differences.
    log_matrix("Deltaf_ij:", Deltaf_ij)
    # Matrix of uncertainties in free energy difference (expectations standard
    # deviations of the estimator about the true free energy).
    log_matrix("dDeltaf_ij:", dDeltaf_ij)

    # Cache the free energy differences and their uncertainties.
    self._computed_observables['free_energy'] = {'value': Deltaf_ij, 'error': dDeltaf_ij}
def get_free_energy(self):
    """
    Compute the free energy and error in free energy from the MBAR object.

    The first call runs ``_compute_free_energy()``, which stores the result in
    ``self._computed_observables['free_energy']``. Later calls return the cached
    values. The output shape depends on whether unsampled states were detected
    in the sampler.

    Returns
    -------
    DeltaF_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
        Difference in free energy from each state relative to each other state
    dDeltaF_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
        Error in the difference in free energy from each state relative to each other state
    """
    if self._computed_observables['free_energy'] is None:
        self._compute_free_energy()
    free_energy_dict = self._computed_observables['free_energy']
    return free_energy_dict['value'], free_energy_dict['error']
def _compute_enthalpy_and_entropy(self):
    """Compute and cache the MBAR enthalpy and entropy estimates.

    Stores ``{'value': ..., 'error': ...}`` dicts under the ``'enthalpy'`` and
    ``'entropy'`` keys of ``self._computed_observables``. The free-energy part
    of the MBAR output (its first two entries) is discarded.
    """
    estimates = self.mbar.computeEntropyAndEnthalpy()
    _, _, delta_h, d_delta_h, delta_s, d_delta_s = estimates
    self._computed_observables['enthalpy'] = dict(value=delta_h, error=d_delta_h)
    self._computed_observables['entropy'] = dict(value=delta_s, error=d_delta_s)
def get_enthalpy(self):
    """
    Compute the difference in enthalpy and error in that estimate from the MBAR object.

    The first call runs ``_compute_enthalpy_and_entropy()``, which stores the
    result in ``self._computed_observables['enthalpy']``. Later calls return the
    cached values. The output shape depends on whether unsampled states were
    detected in the sampler.

    Returns
    -------
    DeltaH_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
        Difference in enthalpy from each state relative to each other state
    dDeltaH_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
        Error in the difference in enthalpy from each state relative to each other state
    """
    if self._computed_observables['enthalpy'] is None:
        self._compute_enthalpy_and_entropy()
    enthalpy_dict = self._computed_observables['enthalpy']
    return enthalpy_dict['value'], enthalpy_dict['error']
def get_entropy(self):
    """
    Compute the difference in entropy and error in that estimate from the MBAR object.

    The first call runs ``_compute_enthalpy_and_entropy()``, which stores the
    result in ``self._computed_observables['entropy']``. Later calls return the
    cached values. The output shape depends on whether unsampled states were
    detected in the sampler.

    Returns
    -------
    DeltaS_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
        Difference in entropy from each state relative to each other state
    dDeltaS_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
        Error in the difference in entropy from each state relative to each other state
    """
    if self._computed_observables['entropy'] is None:
        self._compute_enthalpy_and_entropy()
    entropy_dict = self._computed_observables['entropy']
    return entropy_dict['value'], entropy_dict['error']
def _get_equilibration_data(self, energies=None, replica_state_indices=None):
    """Detect equilibration following best practices and cache the result.

    Parameters
    ----------
    energies : ndarray of shape (K,L,N), optional, Default: None
        Energies from replicas K, sampled states L, and iterations N.
        If provided, then replica_state_indices must also be provided.
    replica_state_indices : ndarray of shape (K,N), optional, Default: None
        Integer indices of the sampled state (matching the L dimension of
        ``energies``) that each replica K visited at every iteration N.
        If provided, then energies must also be provided.

    Returns
    -------
    n_equilibration_iterations : int
        Includes the discarded minimization frame.
    statistical_inefficiency : float
    n_uncorrelated_iterations : int
    """
    effective_energies = self.get_effective_energy_timeseries(energies, replica_state_indices)
    # Equilibration detection skips u_n[0], the energy right after minimization.
    # Including it makes the detected number of equilibration iterations 0.
    # TODO: Find a better way than just discarding the first frame.
    n_equilibration, *remaining = utils.get_equilibration_data(effective_energies[1:])
    # Count the discarded minimization frame as an equilibration iteration too.
    detected = [n_equilibration + 1] + remaining
    self._equilibration_data = tuple(detected)
    logger.debug('Equilibration data: {}'.format(detected))
    return self._equilibration_data
# -------------------------------------------------------------------------
# Cached properties.
# -------------------------------------------------------------------------
# Analysis options exposed as CachedProperty descriptors. They are listed in the
# ``dependencies`` of ``_unbiased_decorrelated_u_ln`` and
# ``_unbiased_decorrelated_N_l``, so (presumably) changing one of them
# invalidates those cached values and, through them, ``mbar``.
# NOTE(review): ``check_changes=True`` is assumed to skip invalidation when the
# new value equals the old one; confirm against CachedProperty.
unbias_restraint = CachedProperty('unbias_restraint', check_changes=True)
restraint_energy_cutoff = CachedProperty('restraint_energy_cutoff', check_changes=True)
restraint_distance_cutoff = CachedProperty('restraint_distance_cutoff', check_changes=True)
_equilibration_data = CachedProperty(
    name='equilibration_data',
    dependencies=['reporter', 'max_n_iterations'],
    check_changes=True,
)


@_equilibration_data.default
def _equilibration_data(self, instance):
    """Default getter: run the equilibration detection of ``instance``."""
    detected = instance._get_equilibration_data()
    return detected
_decorrelated_state_indices_ln = CachedProperty(
    name='decorrelated_state_indices_ln',
    dependencies=['equilibration_data'],
)


@_decorrelated_state_indices_ln.default
def _decorrelated_state_indices_ln(self, instance):
    """Return the thermodynamic state index of each decorrelated frame, in ln order.

    Frames are laid out replica-major: frame ``r * n_iterations + i`` holds the
    state visited by replica ``r`` at the ``i``-th decorrelated iteration. As
    side effects, ``instance._n_replicas`` is refreshed from the reporter data
    and the result is also assigned to ``instance._decorrelated_state_indices_ln``.
    """
    iterations = instance._decorrelated_iterations
    n_iterations = len(iterations)
    # Indexed as [iteration, replica] below.
    states_by_iteration = instance._reporter.read_replica_thermodynamic_states()
    _, instance._n_replicas = states_by_iteration.shape

    state_indices_ln = np.zeros(instance.n_replicas * n_iterations, dtype=np.int32)
    for replica_idx in range(instance.n_replicas):
        offset = replica_idx * n_iterations
        for iteration_idx, iteration in enumerate(iterations):
            state_indices_ln[offset + iteration_idx] = states_by_iteration[iteration, replica_idx]

    instance._decorrelated_state_indices_ln = state_indices_ln
    return state_indices_ln
_decorrelated_u_ln = CachedProperty(
    name='decorrelated_u_ln',
    dependencies=['equilibration_data'],
)


@_decorrelated_u_ln.default
def _decorrelated_u_ln(self, instance):
    """Default getter: entry 0 of ``instance._compute_mbar_decorrelated_energies()``."""
    mbar_inputs = instance._compute_mbar_decorrelated_energies()
    return mbar_inputs[0]
_decorrelated_N_l = CachedProperty(
    name='decorrelated_N_l',
    dependencies=['equilibration_data'],
)


@_decorrelated_N_l.default
def _decorrelated_N_l(self, instance):
    """Default getter: entry 1 of ``instance._compute_mbar_decorrelated_energies()``."""
    mbar_inputs = instance._compute_mbar_decorrelated_energies()
    return mbar_inputs[1]
_unbiased_decorrelated_u_ln = CachedProperty(
    name='unbiased_decorrelated_u_ln',
    dependencies=['unbias_restraint', 'restraint_energy_cutoff', 'restraint_distance_cutoff',
                  'decorrelated_state_indices_ln', 'decorrelated_u_ln', 'decorrelated_N_l'],
)


@_unbiased_decorrelated_u_ln.default
def _unbiased_decorrelated_u_ln(self, instance):
    """Default getter: entry 0 of ``instance._compute_mbar_unbiased_energies()``."""
    unbiased = instance._compute_mbar_unbiased_energies()
    return unbiased[0]
_unbiased_decorrelated_N_l = CachedProperty(
    name='unbiased_decorrelated_N_l',
    dependencies=['unbias_restraint', 'restraint_energy_cutoff', 'restraint_distance_cutoff',
                  'decorrelated_state_indices_ln', 'decorrelated_u_ln', 'decorrelated_N_l'],
)


@_unbiased_decorrelated_N_l.default
def _unbiased_decorrelated_N_l(self, instance):
    """Default getter: entry 1 of ``instance._compute_mbar_unbiased_energies()``."""
    unbiased = instance._compute_mbar_unbiased_energies()
    return unbiased[1]
mbar = CachedProperty(
    name='mbar',
    dependencies=['unbiased_decorrelated_u_ln', 'unbiased_decorrelated_N_l'],
)


@mbar.default
def mbar(self, instance):
    """Default getter: build the MBAR object via ``instance._create_mbar``.

    Its inputs are ``_unbiased_decorrelated_u_ln`` and ``_unbiased_decorrelated_N_l``.
    """
    u_ln = instance._unbiased_decorrelated_u_ln
    N_l = instance._unbiased_decorrelated_N_l
    return instance._create_mbar(u_ln, N_l)
# -------------------------------------------------------------------------
# Dynamic properties.
# -------------------------------------------------------------------------
@property
def n_equilibration_iterations(self):
    """int: The number of equilibration iterations.

    First entry of ``self._equilibration_data``. As computed by
    ``_get_equilibration_data``, it also counts the discarded minimization frame.
    """
    return self._equilibration_data[0]
@property
def statistical_inefficiency(self):
    """float: The statistical inefficiency of the sampler.

    Second entry of the cached ``_equilibration_data``.
    """
    equilibration_data = self._equilibration_data
    return equilibration_data[1]
@property
def _decorrelated_iterations(self):
    """ndarray of int: the indices of the decorrelated iterations truncated to max_n_iterations.

    Iterations ``n_equilibration_iterations`` through ``max_n_iterations``
    (inclusive) are subsampled with pymbar's ``subsampleCorrelatedData``, using
    ``statistical_inefficiency`` as the correlation time.
    """
    first_iteration = self.n_equilibration_iterations
    last_iteration = self.max_n_iterations
    candidate_iterations = np.array(range(first_iteration, last_iteration + 1))
    selected = timeseries.subsampleCorrelatedData(candidate_iterations,
                                                  self.statistical_inefficiency)
    return candidate_iterations[selected]
# https://choderalab.slack.com/files/levi.naden/F4G6L9X8S/quick_diagram.png
class MultiPhaseAnalyzer(object):
    """
    Multiple Phase Analyzer creator, not to be directly called itself, but instead called by adding or subtracting
    different implemented :class:`PhaseAnalyzer` or other :class:`MultiPhaseAnalyzer` objects. The individual Phases
    of the :class:`MultiPhaseAnalyzer` are only references to existing Phase objects, not copies. All
    :class:`PhaseAnalyzer` and :class:`MultiPhaseAnalyzer` classes support ``+`` and ``-`` operations.

    The observables of this phase are determined through inspection of all the passed in phases and only observables
    which are shared can be computed. For example:

        ``PhaseA`` has ``.get_free_energy`` and ``.get_entropy``

        ``PhaseB`` has ``.get_free_energy`` and ``.get_enthalpy``,

        ``PhaseAB = PhaseA + PhaseB`` will only have a ``.get_free_energy`` method

    Because each Phase may have a different number of states, the ``reference_states`` property of each phase
    determines which states from each phase to read the data from.

    For observables defined by two states, the i'th and j'th reference states are used:

        If we define ``PhaseAB = PhaseA - PhaseB``

        Then ``PhaseAB.get_free_energy()`` is roughly equivalent to doing the following:

            ``A_i, A_j = PhaseA.reference_states``

            ``B_i, B_j = PhaseB.reference_states``

            ``PhaseA.get_free_energy()[A_i, A_j] - PhaseB.get_free_energy()[B_i, B_j]``

        The above is not exact since get_free_energy returns an error estimate as well.

    For observables defined by a single state, only the i'th reference state is used:

        Given ``PhaseAB = PhaseA + PhaseB``, ``PhaseAB.get_temperature()`` is equivalent to:

            ``A_i = PhaseA.reference_states[0]``

            ``B_i = PhaseB.reference_states[0]``

            ``PhaseA.get_temperature()[A_i] + PhaseB.get_temperature()[B_i]``

    For observables defined entirely by the phase, no reference states are needed:

        Given ``PhaseAB = PhaseA + PhaseB``, ``PhaseAB.get_standard_state_correction()`` gives:

            ``PhaseA.get_standard_state_correction() + PhaseB.get_standard_state_correction()``

    Each phase MUST use the same ObservablesRegistry, otherwise an error is raised.

    This class is public to see its API.

    Parameters
    ----------
    phases : dict
        has keys "phases", "names", and "signs"

    Attributes
    ----------
    observables
    phases
    names
    signs
    registry

    See Also
    --------
    PhaseAnalyzer
    ObservablesRegistry
    """

    def __init__(self, phases):
        """
        Create the compound phase which is any combination of phases to generate a new MultiPhaseAnalyzer.

        Parameters
        ----------
        phases : dict
            Has keys ``'phases'`` (the phase objects), ``'names'`` (their names) and ``'signs'``
            (``'+'`` or ``'-'`` for each phase). The three sequences are paired element by element.

        Raises
        ------
        ValueError
            If the phases do not all share the same ObservablesRegistry instance.
        RuntimeError
            If the phases have no observable in common.
        """
        # Compare ObservableRegistries.
        ref_registry = phases['phases'][0].registry
        for phase in phases['phases'][1:]:
            # Identity comparison: every phase must share the very same registry instance.
            if phase.registry is not ref_registry:
                raise ValueError("Not all phases have the same ObservablesRegistry! Observable calculation "
                                 "will be inconsistent!")
        self.registry = ref_registry
        # Only observables that every phase can compute are exposed (registry order is preserved).
        observables = [observable for observable in self.registry.observables
                       if all(observable in phase.observables for phase in phases['phases'])]
        if len(observables) == 0:
            raise RuntimeError("There are no shared computable observable between the phases, combining them will do "
                               "nothing.")
        self._observables = tuple(observables)
        self._phases = phases['phases']
        self._names = phases['names']
        self._signs = phases['signs']
        # Create one get_<observable> method per shared observable.
        for observable in self.observables:
            setattr(self, "get_" + observable, self._spool_function(observable))

    def _spool_function(self, observable):
        """
        Dynamic observable calculator layer.

        This must live in its own function to isolate the variable namespace. If the lambda were created
        directly in a loop in ``__init__``, the loop variable "observable" would be late-bound and every
        generated method would fetch the wrong (last) observable.

        The returned callable is bound to ``self``, the instance this method is called on.
        """
        return lambda: self._compute_observable(observable)

    @property
    def observables(self):
        """List of observables this :class:`MultiPhaseAnalyzer` can generate"""
        return self._observables

    @property
    def phases(self):
        """List of implemented :class:`PhaseAnalyzer`'s objects this :class:`MultiPhaseAnalyzer` is tied to"""
        return self._phases

    @property
    def names(self):
        """
        Unique list of string names identifying this phase. If this :class:`MultiPhaseAnalyzer` is combined with
        another, it's possible that new names will be generated unique to that :class:`MultiPhaseAnalyzer`, but they
        will still reference the same phase.

        When in doubt, use :func:`MultiPhaseAnalyzer.phases` to get the actual phase objects.
        """
        return self._names

    @property
    def signs(self):
        """
        List of signs (``'+'`` or ``'-'``), one per entry of :attr:`phases`. Each phase's observable is added to
        (``'+'``) or subtracted from (``'-'``) the combined result accordingly.
        """
        return self._signs

    def clear(self):
        """
        Clear the individual phases of their observables and estimators for re-computing quantities
        """
        for phase in self.phases:
            phase.clear()

    @staticmethod
    def _resolve_sign(operator, sign):
        """Return the sign a phase carrying ``sign`` gets when joined to a combination with ``operator``."""
        if (operator == '-' and sign == '+') or (operator == '+' and sign == '-'):
            return '-'
        return '+'

    def _combine_phases(self, other, operator='+'):
        """
        Function to combine the phases regardless of operator to reduce code duplication. Creates a new
        :class:`MultiPhaseAnalyzer` object based on the combined phases of the other. Accepts either a
        :class:`PhaseAnalyzer` or a :class:`MultiPhaseAnalyzer`.

        If the names have collision, they are re-named with an extra digit at the end.

        Parameters
        ----------
        other : :class:`MultiPhaseAnalyzer` or :class:`PhaseAnalyzer`
        operator : sign of the operator connecting the two objects

        Returns
        -------
        output : :class:`MultiPhaseAnalyzer`
            New :class:`MultiPhaseAnalyzer` where the phases are the combined list of the individual phases from each
            component. Because the memory pointers to the individual phases are the same, changing any
            single :class:`PhaseAnalyzer`'s
            reference_state objects updates all :class:`MultiPhaseAnalyzer` objects they are tied to

        Raises
        ------
        TypeError
            If ``other`` is neither a :class:`MultiPhaseAnalyzer` nor a :class:`PhaseAnalyzer`.
        """
        # Work on copies so this object's own lists are never mutated.
        phases = list(self.phases)
        names = list(self.names)
        signs = list(self.signs)
        if isinstance(other, MultiPhaseAnalyzer):
            new_names = other.names
            final_new_names = []
            for name in new_names:
                other_names = [n for n in new_names if n != name]
                final_new_names.append(utils.generate_phase_name(name, other_names + names))
            names.extend(final_new_names)
            # Exactly one resolved sign per incoming phase. Appending the incoming raw signs as well
            # would make ``signs`` longer than ``phases`` and shift the signs of later combinations.
            signs.extend(self._resolve_sign(operator, new_sign) for new_sign in other.signs)
            phases.extend(other.phases)
        elif isinstance(other, PhaseAnalyzer):
            names.append(utils.generate_phase_name(other.name, names))
            signs.append(self._resolve_sign(operator, other._sign))
            other._sign = '+'  # Recast to positive if negated
            phases.append(other)
        else:
            action = 'add' if operator == '+' else 'subtract'
            raise TypeError("cannot {} 'MultiPhaseAnalyzer' and '{}' objects".format(action, type(other).__name__))
        phase_pass = {'phases': phases, 'signs': signs, 'names': names}
        return MultiPhaseAnalyzer(phase_pass)

    def __add__(self, other):
        return self._combine_phases(other, operator='+')

    def __sub__(self, other):
        return self._combine_phases(other, operator='-')

    def __neg__(self):
        """
        Return a SHALLOW copy of self with negated signs so that the phase objects all still point to the same
        objects
        """
        new_signs = ['-' if sign == '+' else '+' for sign in self._signs]
        # Return a *shallow* copy of self with the signs reversed.
        output = copy.copy(self)
        output._signs = new_signs
        # copy.copy duplicates the instance __dict__, including the get_<observable> closures created in
        # __init__, which still reference the original ``self`` and its signs. Rebind them to the copy.
        for observable in output.observables:
            setattr(output, "get_" + observable, output._spool_function(observable))
        return output

    def __str__(self):
        """Simplified string output"""
        header = "MultiPhaseAnalyzer<{}>"
        output_string = ""
        for phase_name, sign in zip(self.names, self.signs):
            if output_string == "" and sign == '-':
                output_string += '{}{} '.format(sign, phase_name)
            elif output_string == "":
                output_string += '{} '.format(phase_name)
            else:
                output_string += '{} {} '.format(sign, phase_name)
        return header.format(output_string)

    def __repr__(self):
        """Generate a detailed representation of the MultiPhase"""
        header = "MultiPhaseAnalyzer <\n{}>"
        output_string = ""
        for phase, phase_name, sign in zip(self.phases, self.names, self.signs):
            if output_string == "" and sign == '-':
                output_string += '{}{} ({})\n'.format(sign, phase_name, phase)
            elif output_string == "":
                output_string += '{} ({})\n'.format(phase_name, phase)
            else:
                output_string += '    {} {} ({})\n'.format(sign, phase_name, phase)
        return header.format(output_string)

    def _compute_observable(self, observable_name):
        """
        Helper function to compute arbitrary observable in both phases

        Parameters
        ----------
        observable_name : str
            Name of the observable as it is defined in the ObservablesRegistry

        Returns
        -------
        observable_value
            The observable combined over all the phases according to :attr:`signs`. For observables registered
            with an error, a ``(value, error)`` tuple.

        Raises
        ------
        RuntimeError
            If the observable is not registered as defined by the phase, by a single state, or by two states.
        """
        def prepare_phase_observable(single_phase):
            """Helper function to cast the observable in terms of observable's registry"""
            observable = getattr(single_phase, "get_" + observable_name)()
            if isinstance(single_phase, MultiPhaseAnalyzer):
                if observable_name in self.registry.observables_with_error:
                    observable_payload = dict()
                    observable_payload['value'], observable_payload['error'] = observable
                else:
                    observable_payload = observable
            else:
                raise_registry_error = False
                if observable_name in self.registry.observables_with_error:
                    observable_payload = {}
                    if observable_name in self.registry.observables_defined_by_phase:
                        observable_payload['value'], observable_payload['error'] = observable
                    elif observable_name in self.registry.observables_defined_by_single_state:
                        observable_payload['value'] = observable[0][single_phase.reference_states[0]]
                        observable_payload['error'] = observable[1][single_phase.reference_states[0]]
                    elif observable_name in self.registry.observables_defined_by_two_states:
                        observable_payload['value'] = observable[0][single_phase.reference_states[0],
                                                                    single_phase.reference_states[1]]
                        observable_payload['error'] = observable[1][single_phase.reference_states[0],
                                                                    single_phase.reference_states[1]]
                    else:
                        raise_registry_error = True
                else:  # No error
                    if observable_name in self.registry.observables_defined_by_phase:
                        observable_payload = observable
                    elif observable_name in self.registry.observables_defined_by_single_state:
                        observable_payload = observable[single_phase.reference_states[0]]
                    elif observable_name in self.registry.observables_defined_by_two_states:
                        observable_payload = observable[single_phase.reference_states[0],
                                                        single_phase.reference_states[1]]
                    else:
                        raise_registry_error = True
                if raise_registry_error:
                    raise RuntimeError("You have requested an observable that is improperly registered in the "
                                       "ObservablesRegistry!")
            return observable_payload

        def modify_final_output(passed_output, payload, sign):
            """Accumulate one phase's payload into the running total using its sign."""
            if observable_name in self.registry.observables_with_error:
                if sign == '+':
                    passed_output['value'] += payload['value']
                else:
                    passed_output['value'] -= payload['value']
                # Errors never cancel: they combine linearly or in quadrature regardless of sign.
                if observable_name in self.registry.observables_with_error_adding_linear:
                    passed_output['error'] += payload['error']
                elif observable_name in self.registry.observables_with_error_adding_quadrature:
                    passed_output['error'] = (passed_output['error']**2 + payload['error']**2)**0.5
            else:
                if sign == '+':
                    passed_output += payload
                else:
                    passed_output -= payload
            return passed_output

        if observable_name in self.registry.observables_with_error:
            final_output = {'value': 0, 'error': 0}
        else:
            final_output = 0
        for phase, phase_sign in zip(self.phases, self.signs):
            phase_observable = prepare_phase_observable(phase)
            final_output = modify_final_output(final_output, phase_observable, phase_sign)
        if observable_name in self.registry.observables_with_error:
            # Cast output to tuple
            final_output = (final_output['value'], final_output['error'])
        return final_output