#!/usr/bin/env python
# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================
"""
Analyze
=======
Analysis tools and module for YANK simulations. Provides programmatic and automatic "best practices" integration to
determine free energy and other observables.
Fully extensible to support new samplers and observables.
"""
# =============================================================================================
# Analyze datafiles produced by YANK.
# =============================================================================================
import os
import os.path
import abc
import copy
import yaml
import numpy as np
import openmmtools as mmtools
from .repex import Reporter
from pymbar import MBAR  # multi-state Bennett acceptance ratio
from pymbar import timeseries  # for statistical inefficiency analysis
import mdtraj
import simtk.unit as units
import logging
# Module-level logger; the analyzers below report progress messages through it.
logger = logging.getLogger(__name__)
# Abstract base class created directly from ABCMeta so that ``class X(ABC)`` works identically on Python 2 and 3
# (Python 2 has no ``abc.ABC``).
ABC = abc.ABCMeta('ABC', (object,), {})  # compatible with Python 2 *and* 3
# =============================================================================================
# PARAMETERS
# =============================================================================================
# Molar Boltzmann constant (k_B * N_A) as a simtk.unit Quantity; multiplying it by a temperature gives kT in
# energy per mole (see YankPhaseAnalyzer.kT).
kB = units.BOLTZMANN_CONSTANT_kB * units.AVOGADRO_CONSTANT_NA
# =============================================================================================
# MODULE FUNCTIONS
# =============================================================================================
def generate_phase_name(current_name, name_list):
    """
    Provide a regular way to generate unique human-readable names from base names.
    Given a base name and a list of existing names, a number will be appended to the base name until a unique string
    is generated.
    Parameters
    ----------
    current_name : string or None
        The base name you wish to ensure is unique. Numbers will be appended to this string until a unique string
        not in the name_list is provided.
        If None, the names "phase0", "phase1", ... are tried in order and the first one not in name_list is returned.
    name_list : iterable of strings
        The current_name, and its modifiers, are compared against this collection until a unique string is found.
        Any iterable (including a one-shot iterator or generator) is accepted; it is consumed exactly once.
    Returns
    -------
    name : string
        Unique string derived from the current_name that is not in name_list.
        If the parameter current_name is not None and not already in the name_list, then current_name is returned
        unmodified.
    Examples
    --------
    >>> generate_phase_name('complex', ['complex', 'complex0'])
    'complex1'
    >>> generate_phase_name(None, [])
    'phase0'
    """
    # Materialize once: repeated ``in`` tests on an iterator would consume it, and a set gives O(1) lookups.
    existing_names = set(name_list)
    if current_name is None:
        base = 'phase'
    elif current_name in existing_names:
        base = current_name
    else:
        # Already unique, keep the user's name untouched.
        return current_name
    counter = 0
    # Plain concatenation (not str.format) so names containing braces are handled safely.
    while base + str(counter) in existing_names:
        counter += 1
    return base + str(counter)
def get_analyzer(file_base_path):
    """
    Utility function to convert storage file to a Reporter and Analyzer by reading the data on file
    For now this is mostly placeholder functions since there is only the implemented :class:`ReplicaExchangeAnalyzer`,
    but creates the API for the user to work with.
    Parameters
    ----------
    file_base_path : string
        Complete path to the storage file with filename and extension.
    Returns
    -------
    analyzer : instance of implemented :class:`YankPhaseAnalyzer`
        Analyzer for the specific phase. Currently this is always a :class:`ReplicaExchangeAnalyzer`.
    """
    # The storage file is opened read-only.
    reporter = Reporter(file_base_path, open_mode='r')
    # TODO: Support more samplers by reading the sampler class name from the storage metadata and letting that
    #  class build the matching reporter/analyzer. Until a second analyzer exists no auto-detection is done; once it
    #  does, an unrecognized sampler should raise a RuntimeError instead of silently falling back.
    return ReplicaExchangeAnalyzer(reporter)
def get_decorrelation_time(timeseries_to_analyze):
    """
    Compute the decorrelation time (statistical inefficiency) of a timeseries.
    Thin wrapper around ``pymbar.timeseries.statisticalInefficiency``; see that function for full documentation.
    Parameters
    ----------
    timeseries_to_analyze : array-like
        Timeseries handed unchanged to ``statisticalInefficiency``.
    Returns
    -------
    g
        Whatever ``statisticalInefficiency`` returns for the timeseries.
    """
    inefficiency = timeseries.statisticalInefficiency(timeseries_to_analyze)
    return inefficiency
def get_equilibration_data(timeseries_to_analyze):
    """
    Detect the equilibration point of a timeseries.
    Thin wrapper around ``pymbar.timeseries.detectEquilibration``; see that function for full documentation.
    Parameters
    ----------
    timeseries_to_analyze : array-like
        Timeseries handed unchanged to ``detectEquilibration``.
    Returns
    -------
    n_equilibration, g_t, n_effective_max : tuple
        The three values produced by ``detectEquilibration``, returned as a tuple.
    """
    detection = timeseries.detectEquilibration(timeseries_to_analyze)
    # Unpacking enforces that exactly three values came back.
    t0, inefficiency, effective_samples = detection
    return t0, inefficiency, effective_samples
def get_equilibration_data_per_sample(timeseries_to_analyze, fast=True, nskip=1):
    """
    Compute the correlation time and n_effective per sample.
    This is exactly what ``pymbar.timeseries.detectEquilibration`` does, but returns the per sample data
    See the ``pymbar.timeseries.detectEquilibration`` function for full documentation
    Parameters
    ----------
    timeseries_to_analyze : np.ndarray
        Timeseries to analyze; its ``.size`` is taken as the number of samples T.
    fast : bool, optional, default=True
        Passed through to ``pymbar.timeseries.statisticalInefficiency``.
    nskip : int, optional, default=1
        Only every ``nskip``-th candidate equilibration time is evaluated; skipped entries keep the value 1.
    Returns
    -------
    g_t : np.ndarray of float32, shape (max(T - 1, 0),)
        Statistical inefficiency of the sub-series starting at each candidate equilibration time t.
    Neff_t : np.ndarray of float32, shape (max(T - 1, 0),)
        Effective number of samples for each candidate equilibration time t.
    """
    A_t = timeseries_to_analyze
    T = A_t.size
    # One candidate per start index 0..T-2; clamp so an empty timeseries returns empty arrays instead of crashing.
    n_candidates = max(T - 1, 0)
    g_t = np.ones([n_candidates], np.float32)
    Neff_t = np.ones([n_candidates], np.float32)
    for t in range(0, T - 1, nskip):
        try:
            g_t[t] = timeseries.statisticalInefficiency(A_t[t:T], fast=fast)
        except Exception:
            # statisticalInefficiency can fail on a degenerate sub-series (presumably zero variance); fall back to
            # treating the whole remainder as one correlated block. Exception (not a bare except) so that
            # KeyboardInterrupt/SystemExit still propagate.
            g_t[t] = (T - t + 1)
        # NOTE: (T - t + 1) mirrors pymbar's detectEquilibration and is kept for consistency with it.
        Neff_t[t] = (T - t + 1) / g_t[t]
    return g_t, Neff_t
def remove_unequilibrated_data(data, number_equilibrated, axis):
    """
    Remove the number_equilibrated samples from a dataset
    Discards number_equilibrated number of indices from given axis
    Parameters
    ----------
    data : np.array-like of any dimension length
        This is the data which will be pared down
    number_equilibrated : int
        Number of indices that will be removed from the given axis, i.e. axis will be shorter by number_equilibrated
    axis : int
        Axis index along which to remove samples from. This supports negative indexing as well
    Returns
    -------
    equilibrated_data : ndarray
        Data with the number_equilibrated number of indices removed from the beginning along axis.
        This is a view of ``np.asarray(data)``, not a copy.
    """
    cast_data = np.asarray(data)
    # Full slice on every axis except the one being truncated.
    slices = [slice(None)] * cast_data.ndim
    slices[axis] = slice(number_equilibrated, None)
    # Multidimensional indexing requires a *tuple*; NumPy >= 1.23 rejects a list of slices.
    return cast_data[tuple(slices)]
def subsample_data_along_axis(data, subsample_rate, axis):
    """
    Produce a decorrelated copy of ``data`` by keeping only every ~``subsample_rate``-th index along one axis.
    Parameters
    ----------
    data : np.array-like of any dimension length
    subsample_rate : float or int
        Statistical inefficiency handed to ``pymbar.timeseries.subsampleCorrelatedData`` as ``g``; a sample is
        considered decorrelated after every ceil(subsample_rate) indices along ``axis``.
    axis : int
        Axis along which the subsampling is applied.
    Returns
    -------
    subsampled_data : ndarray with the same number of dimensions as data
        ``data`` restricted to the retained indices along ``axis``.
    """
    # TODO: find a name for the function that clarifies that decorrelation
    # TODO:             is determined exclusively by subsample_rate?
    array_data = np.asarray(data)
    # With ``g`` supplied, subsampleCorrelatedData only needs a series of the right length, so zeros suffice.
    placeholder_series = np.zeros(array_data.shape[axis])
    kept_indices = timeseries.subsampleCorrelatedData(placeholder_series, g=subsample_rate)
    return np.take(array_data, kept_indices, axis=axis)
# =============================================================================================
# MODULE CLASSES
# =============================================================================================
class _ObservablesRegistry(object):
    """
    Static registry of the observables that phase analyzers know how to compute.
    :class:`YankPhaseAnalyzer` and :class:`MultiPhaseAnalyzer` consult this hidden class to learn which observables
    exist and how each one is categorized.
    Adding a new observable:
    1) Pick a unique observable name.
    2) Add it to exactly one of the "observables_defined_by_X" methods:
     2a) "defined_by_phase": depends on the Phase as a whole (state independent), e.g. Standard State Correction
     2b) "defined_by_single_state": computed entirely from one state, e.g. Radius of Gyration
     2c) "defined_by_two_states": relative to a reference state, e.g. Free Energy Difference
    3) Optionally add it to one "observables_with_error_adding_Y" method. Observables in neither are assumed to carry
       no error. With A, B, C the observable in three phases and eA, eB, eC their errors:
     3a) "linear": If C = A + B, eC = eA + eB
     3b) "quadrature": If C = A + B, eC = sqrt(eA^2 + eB^2)
    4) Implement "get_{observable name}" on the :class:`YankPhaseAnalyzer` subclass. A :class:`MultiPhaseAnalyzer`
       built from such phases exposes "get_{observable name}" when every one of its phases implements it.
    """
    # ------------------------------------------------
    # All observables
    # ------------------------------------------------
    @staticmethod
    def observables():
        """
        Every registered observable: the union of the exclusive category subsets below.
        Returned as a tuple built from a set, so no particular ordering is guaranteed.
        """
        categories = (_ObservablesRegistry.observables_defined_by_two_states(),
                      _ObservablesRegistry.observables_defined_by_single_state(),
                      _ObservablesRegistry.observables_defined_by_phase())
        return tuple(set().union(*categories))
    # ------------------------------------------------
    # Exclusive observable categories (pairwise disjoint)
    # ------------------------------------------------
    @staticmethod
    def observables_defined_by_two_states():
        """Observables that need a reference state i and a state j to be compared between phases"""
        return ('entropy', 'enthalpy', 'free_energy')
    @staticmethod
    def observables_defined_by_single_state():
        """Observables fully determined by one state, not by a difference between several states"""
        return ()
    @staticmethod
    def observables_defined_by_phase():
        """
        Observables belonging to the phase as a whole rather than to any one or more states
        e.g. Standard State Correction
        """
        return ('standard_state_correction',)
    # ------------------------------------------------
    # Observables carrying error (subset of observables())
    # ------------------------------------------------
    @staticmethod
    def observables_with_error():
        """Every observable that carries an error, i.e. the union of the error-category subsets below"""
        quadrature = set(_ObservablesRegistry.observables_with_error_adding_quadrature())
        linear = set(_ObservablesRegistry.observables_with_error_adding_linear())
        return tuple(quadrature | linear)
    # ------------------------------------------------
    # Exclusive error categories (pairwise disjoint)
    # ------------------------------------------------
    @staticmethod
    def observables_with_error_adding_quadrature():
        """Observable C = A + B, Error eC = sqrt(eA**2 + eB**2)"""
        return ('entropy', 'enthalpy', 'free_energy')
    @staticmethod
    def observables_with_error_adding_linear():
        """Observable C = A + B, Error eC = eA + eB"""
        return ()
# ---------------------------------------------------------------------------------------------
# Phase Analyzers
# ---------------------------------------------------------------------------------------------
class YankPhaseAnalyzer(ABC):
    """
    Analyzer for a single phase of a YANK simulation.
    Uses the reporter from the simulation to determine the location
    of all variables.
    To compute a specific observable in an implementation of this class, add it to the ObservablesRegistry and then
    implement a ``get_X`` where ``X`` is the name of the observable you want to compute. See the ObservablesRegistry for
    information about formatting the observables.
    Analyzer works in units of kT unless specifically stated otherwise. To convert back to a unit set, just multiply by
    the .kT property.
    Phases can be combined with ``+`` and ``-`` (and negated with unary ``-``) to build a :class:`MultiPhaseAnalyzer`.
    Parameters
    ----------
    reporter : Reporter instance
        Reporter from Repex which ties to the simulation data on disk. It is opened in read mode if not already open.
    name : str, Optional
        Unique name you want to assign this phase, this is the name that will appear in :class:`MultiPhaseAnalyzer`'s.
        If not set, it will be given the arbitrary name "phase#" where # is an integer, chosen in order that it is
        assigned to the :class:`MultiPhaseAnalyzer`.
    reference_states : tuple of ints, length 2, Optional, Default: (0,-1)
        Integers ``i`` and ``j`` of the state that is used for reference in observables, "O". Python and numpy
        integers are accepted. These values are only used
        when reporting single numbers or combining observables through :class:`MultiPhaseAnalyzer` (since the number of
        states between phases can be different). Calls to functions such as ``get_free_energy`` in a single Phase
        results in the O being returned for all states.
            For O completely defined by the state itself (i.e. no differences between states, e.g. Temperature),
            only O[i] is used
            For O where differences between states are required (e.g. Free Energy): O[i,j] = O[j] - O[i]
            For O defined by the phase as a whole, the reference states are not needed.
    analysis_kwargs : None or dict, optional
        Dictionary of extra keyword arguments to pass into the analysis tool, typically MBAR.
        For instance, the initial guess of relative free energies to give to MBAR would be something like:
        ``{'initial_f_k':[0,1,2,3]}``
    Raises
    ------
    ValueError
        If ``reference_states`` does not contain integers, or ``analysis_kwargs`` is neither None nor a dict.
    Attributes
    ----------
    name
    observables
    mbar
    reference_states
    kT
    reporter
    """
    def __init__(self, reporter, name=None, reference_states=(0, -1), analysis_kwargs=None):
        """
        The reporter provides the hook into how to read the data, all other options control where differences are
        measured from and how each phase interfaces with other phases.
        """
        if not reporter.is_open():
            reporter.open(mode='r')
        self._reporter = reporter
        observables = []
        # Auto-determine the computable observables by inspection of non-flagged methods
        # We determine valid observables by negation instead of just having each child implement the method to enforce
        # uniform function naming conventions.
        self._computed_observables = {}  # Cache of observables so the phase can be retrieved once computed
        for observable in _ObservablesRegistry.observables():
            if hasattr(self, "get_" + observable):
                observables.append(observable)
                self._computed_observables[observable] = None
        # Cast observables to an immutable
        self._observables = tuple(observables)
        # Internal properties
        self._name = name
        # Start as default sign +, handle all sign conversion at operation time
        self._sign = '+'
        self._equilibration_data = None  # Internal tracker so the functions can get this data without recalculating it
        # External properties
        self._reference_states = None  # initialize the cache object
        self.reference_states = reference_states
        self._mbar = None
        self._kT = None
        # isinstance (rather than an exact type check) so dict subclasses such as OrderedDict are accepted too
        if analysis_kwargs is not None and not isinstance(analysis_kwargs, dict):
            raise ValueError('analysis_kwargs must be either None or a dictionary')
        self._extra_analysis_kwargs = analysis_kwargs if (analysis_kwargs is not None) else dict()
    @property
    def name(self):
        """User-readable string name of the phase"""
        return self._name
    @name.setter
    def name(self, value):
        self._name = value
    @property
    def observables(self):
        """
        List of observables that the instanced analyzer can compute/fetch.
        This list is automatically compiled upon class initialization based on the functions implemented in the subclass
        """
        return self._observables
    @property
    def mbar(self):
        """MBAR object tied to this phase, built lazily with _create_mbar_from_scratch on first access"""
        if self._mbar is None:
            self._create_mbar_from_scratch()
        return self._mbar
    @property
    def reference_states(self):
        """Tuple of reference states ``i`` and ``j`` for :class:`MultiPhaseAnalyzer` instances"""
        return self._reference_states
    @reference_states.setter
    def reference_states(self, value):
        """
        Provide a way to re-assign the ``i, j`` states in a protected way.
        Only the first two entries of ``value`` are used. Python and numpy integers are accepted (bools are not) and
        are stored as plain Python ints.
        """
        i, j = value[0], value[1]
        for index in (i, j):
            # bool is an int subclass, but a boolean state index is almost certainly a mistake
            if isinstance(index, bool) or not isinstance(index, (int, np.integer)):
                raise ValueError("reference_states must be a length 2 iterable of ints")
        self._reference_states = (int(i), int(j))
    @property
    def kT(self):
        """
        Quantity of Boltzmann constant times temperature of the phase in units of energy per mol
        Allows conversion between dimensionless energy and unit bearing energy
        The temperature of the first thermodynamic state is read on first access and the result is cached.
        """
        if self._kT is None:
            thermodynamic_states, _ = self._reporter.read_thermodynamic_states()
            temperature = thermodynamic_states[0].temperature
            self._kT = kB * temperature
        return self._kT
    @property
    def reporter(self):
        """Sampler Reporter tied to this object."""
        return self._reporter
    @reporter.setter
    def reporter(self, value):
        """Make sure users cannot overwrite the reporter."""
        raise ValueError("You cannot re-assign the reporter for this analyzer!")
    # Abstract methods
    @abc.abstractmethod
    def analyze_phase(self, *args, **kwargs):
        """
        Auto-analysis function for the phase
        Function which broadly handles "auto-analysis" for those that do not wish to call all the methods on their own.
        This should behave like the old "analyze" function from versions of YANK pre-1.0.
        Returns a dictionary of analysis objects
        """
        raise NotImplementedError()
    @abc.abstractmethod
    def _create_mbar_from_scratch(self):
        """
        This method should automatically do everything needed to make the MBAR object from file. It should make all
        the assumptions needed to make the MBAR object.  Typically a lot of these functions will be needed for the
        :func:`analyze_phase` function.
        Should call the :func:`_prepare_mbar_input_data` to get the data ready for MBAR.
        Returns nothing, but the self.mbar object should be set after this.
        """
        raise NotImplementedError()
    @abc.abstractmethod
    def _prepare_mbar_input_data(self, *args, **kwargs):
        """
        Prepare a set of data for MBAR, because each analyzer may need to do something else to prepare for MBAR, it
        should have its own function to do that with.
        Parameters
        ----------
        args : arguments needed to generate the appropriate Returns
        kwargs : keyword arguments needed to generate the appropriate Returns
        Returns
        -------
        energy_matrix : energy matrix of shape (K,L,N), indexed by k,l,n
            K is the total number of sampled states
            L is the total states we want MBAR to analyze
            N is the total number of samples
            The kth sample was drawn from state k at iteration n,
                the nth configuration of kth state is evaluated in thermodynamic state l
        samples_per_state : 1-D iterable of shape L
            The total number of samples drawn from each lth state
        """
        raise NotImplementedError()
    @abc.abstractmethod
    def get_states_energies(self):
        """
        Extract the deconvoluted energies from a phase.
        Energies from this are NOT decorrelated.
        Returns
        -------
        sampled_energy_matrix : numpy.ndarray of shape K,K,N
            Deconvoluted energy of sampled states evaluated at other sampled states.
            Has shape (K,K,N) = (number of sampled states, number of sampled states, number of iterations)
            Indexed by [k,l,n] where an energy drawn from sampled state [k] is evaluated in sampled state [l] at
            iteration [n]
        unsampled_energy_matrix : numpy.ndarray of shape K,L,N
            Has shape (K, L, N) = (number of sampled states, number of UN-sampled states, number of iterations)
            Indexed by [k,l,n]
            where an energy drawn from sampled state [k] is evaluated in un-sampled state [l] at iteration [n]
        """
        raise NotImplementedError()
    # This SHOULD be an abstract static method since its related to the analyzer, but could handle any input data
    # Until we drop Python 2.7, we have to keep this method
    # @abc.abstractmethod
    @staticmethod
    def get_timeseries(passed_timeseries):
        """
        Generate the timeseries that is generated for this phase
        Returns
        -------
        generated_timeseries : 1-D iterable
            timeseries which can be fed into get_decorrelation_time to get the decorrelation
        """
        raise NotImplementedError("This class has not implemented this function")
    # Private Class Methods
    def _create_mbar(self, energy_matrix, samples_per_state):
        """
        Initialize MBAR for Free Energy and Enthalpy estimates, this may take a while.
        This function is helpful for those who want to create a slightly different mbar object with different
        parameters.
        This function is hidden from the user unless they really, really need to create their own mbar object
        Any cached observables are invalidated because they depended on the previous estimator.
        Parameters
        ----------
        energy_matrix : array of numpy.float64
           Reduced potential energies passed as the first argument to ``pymbar.MBAR``, e.g. as produced by
           :func:`_prepare_mbar_input_data`
        samples_per_state : array of ints
           Number of samples drawn from each kth state, passed as the second argument to ``pymbar.MBAR``
        """
        # Delete observables cache since we are now resetting the estimator
        for observable in self.observables:
            self._computed_observables[observable] = None
        # Initialize MBAR (computing free energy estimates, which may take a while)
        logger.info("Computing free energy differences...")
        mbar = MBAR(energy_matrix, samples_per_state, **self._extra_analysis_kwargs)
        self._mbar = mbar
    def _combine_phases(self, other, operator='+'):
        """
        Workhorse function when creating a :class:`MultiPhaseAnalyzer` object by combining single
        :class:`YankPhaseAnalyzer` objects.
        ``self`` enters with the sign currently stored on it (set by unary negation). Every phase of ``other`` keeps
        its own sign when ``operator`` is ``'+'`` and has it flipped otherwise, so ``a + (-b)`` yields signs
        ``['+', '-']`` while ``a - (-b)`` yields ``['+', '+']``. Signs stored on the combined analyzers are reset to
        ``'+'`` afterwards.
        Parameters
        ----------
        other : YankPhaseAnalyzer or MultiPhaseAnalyzer
            Phase(s) to combine with this one.
        operator : str, optional, default='+'
            ``'+'`` adds ``other``; any other value is treated as subtraction.
        Returns
        -------
        multiphase : MultiPhaseAnalyzer
        Raises
        ------
        TypeError
            If ``other`` is neither a :class:`YankPhaseAnalyzer` nor a :class:`MultiPhaseAnalyzer`.
        """
        def _combined_sign(sign):
            # Adding preserves the right-hand operand's sign; subtracting flips it.
            if operator == '+':
                return sign
            return '-' if sign == '+' else '+'
        phases = [self]
        names = []
        signs = [self._sign]
        # Reset self._sign so a negation only applies to this one combination
        self._sign = '+'
        if self.name is None:
            # Pass the missing name (not the analyzer object) so an auto-generated "phase#" name is produced
            names.append(generate_phase_name(None, []))
        else:
            names.append(self.name)
        if isinstance(other, MultiPhaseAnalyzer):
            new_phases = other.phases
            new_signs = other.signs
            new_names = other.names
            final_new_names = []
            for name in new_names:
                other_names = [n for n in new_names if n != name]
                final_new_names.append(generate_phase_name(name, other_names + names))
            names.extend(final_new_names)
            signs.extend(_combined_sign(new_sign) for new_sign in new_signs)
            phases.extend(new_phases)
        elif isinstance(other, YankPhaseAnalyzer):
            names.append(generate_phase_name(other.name, names))
            signs.append(_combined_sign(other._sign))
            # Reset the other's sign if it got set to negative
            other._sign = '+'
            phases.append(other)
        else:
            baseerr = "cannot {} 'YankPhaseAnalyzer' and '{}' objects"
            if operator == '+':
                err = baseerr.format('add', type(other))
            else:
                err = baseerr.format('subtract', type(other))
            raise TypeError(err)
        phase_pass = {'phases': phases, 'signs': signs, 'names': names}
        return MultiPhaseAnalyzer(phase_pass)
    def __add__(self, other):
        return self._combine_phases(other, operator='+')
    def __sub__(self, other):
        return self._combine_phases(other, operator='-')
    def __neg__(self):
        """
        Flip the sign this phase will carry in the next combination.
        Note: this mutates and returns ``self`` rather than a copy.
        """
        if self._sign == '+':
            self._sign = '-'
        else:
            self._sign = '+'
        return self
[docs]class ReplicaExchangeAnalyzer(YankPhaseAnalyzer):
    """
    The ReplicaExchangeAnalyzer is the analyzer for a simulation generated from a Replica Exchange sampler simulation,
    implemented as an instance of the :class:`YankPhaseAnalyzer`.
    See Also
    --------
    YankPhaseAnalyzer
    """
[docs]    def generate_mixing_statistics(self, number_equilibrated=None):
        """
        Generate the Transition state matrix and sorted eigenvalues
        Parameters
        ----------
        number_equilibrated : int, optional, default=None
           If specified, only samples number_equilibrated:end will be used in analysis
           If not specified, automatically retrieves the number from equilibration data or generates it from the
           internal energy.
        Returns
        -------
        mixing_stats : np.array of shape [nstates, nstates]
            Transition matrix estimate
        mu : np.array
            Eigenvalues of the Transition matrix sorted in descending order
        """
        # Read data from disk
        if number_equilibrated is None:
            if self._equilibration_data is None:
                self._get_equilibration_data_auto()
            number_equilibrated, _, _ = self._equilibration_data
        states = self._reporter.read_replica_thermodynamic_states()
        n_iterations, n_states = states.shape
        n_ij = np.zeros([n_states, n_states], np.int64)
        # Compute empirical transition count matrix.
        for iteration in range(number_equilibrated, n_iterations - 1):
            for i_replica in range(n_states):
                i_state = states[iteration, i_replica]
                j_state = states[iteration + 1, i_replica]
                n_ij[i_state, j_state] += 1
        # Compute transition matrix estimate.
        # TODO: Replace with maximum likelihood reversible count estimator from msmbuilder or pyemma.
        t_ij = np.zeros([n_states, n_states], np.float64)
        for i_state in range(n_states):
            # Cast to float to ensure we dont get integer division
            denominator = float((n_ij[i_state, :].sum() + n_ij[:, i_state].sum()))
            if denominator > 0:
                for j_state in range(n_states):
                    t_ij[i_state, j_state] = (n_ij[i_state, j_state] + n_ij[j_state, i_state]) / denominator
            else:
                t_ij[i_state, i_state] = 1.0
        # Estimate eigenvalues
        mu = np.linalg.eigvals(t_ij)
        mu = -np.sort(-mu)  # Sort in descending order
        return t_ij, mu 
[docs]    def show_mixing_statistics(self, cutoff=0.05, number_equilibrated=None):
        """
        Print summary of mixing statistics. Passes information off to generate_mixing_statistics then prints it out to
        the logger
        Parameters
        ----------
        cutoff : float, optional, default=0.05
           Only transition probabilities above 'cutoff' will be printed
        number_equilibrated : int, optional, default=None
           If specified, only samples number_equilibrated:end will be used in analysis
           If not specified, it uses the internally held statistics best
        """
        Tij, mu = self.generate_mixing_statistics(number_equilibrated=number_equilibrated)
        # Print observed transition probabilities.
        nstates = Tij.shape[1]
        logger.info("Cumulative symmetrized state mixing transition matrix:")
        str_row = "{:6s}".format("")
        for jstate in range(nstates):
            str_row += "{:6d}".format(jstate)
        logger.info(str_row)
        for istate in range(nstates):
            str_row = ""
            str_row += "{:-6d}".format(istate)
            for jstate in range(nstates):
                P = Tij[istate, jstate]
                if P >= cutoff:
                    str_row += "{:6.3f}".format(P)
                else:
                    str_row += "{:6s}".format("")
            logger.info(str_row)
        # Estimate second eigenvalue and equilibration time.
        if mu[1] >= 1:
            logger.info("Perron eigenvalue is unity; Markov chain is decomposable.")
        else:
            logger.info("Perron eigenvalue is {0:9.5f}; state equilibration timescale is ~ {1:.1f} iterations".format(
                mu[1], 1.0 / (1.0 - mu[1]))
            ) 
[docs]    def get_states_energies(self):
        """
        Extract and decorrelate energies from the ncfile to gather energies common data for other functions
        Returns
        -------
        energy_matrix : ndarray of shape (K,K,N)
            Potential energy matrix of the sampled states
            Energy is from each sample n, drawn from state (first k), and evaluated at every sampled state (second k)
        unsampled_energy_matrix : ndarray of shape (K,L,N)
            Potential energy matrix of the unsampled states
            Energy from each sample n, drawn from sampled state k, evaluated at unsampled state l
            If no unsampled states were drawn, this will be shape (K,0,N)
        """
        logger.info("Reading energies...")
        energy_thermodynamic_states, energy_unsampled_states = self._reporter.read_energies()
        n_iterations, _, n_states = energy_thermodynamic_states.shape
        _, _, n_unsampled_states = energy_unsampled_states.shape
        energy_matrix_replica = np.zeros([n_states, n_states, n_iterations], np.float64)
        unsampled_energy_matrix_replica = np.zeros([n_states, n_unsampled_states, n_iterations], np.float64)
        for n in range(n_iterations):
            energy_matrix_replica[:, :, n] = energy_thermodynamic_states[n, :, :]
            unsampled_energy_matrix_replica[:, :, n] = energy_unsampled_states[n, :, :]
        logger.info("Done.")
        logger.info("Deconvoluting replicas...")
        energy_matrix = np.zeros([n_states, n_states, n_iterations], np.float64)
        unsampled_energy_matrix = np.zeros([n_states, n_unsampled_states, n_iterations], np.float64)
        for iteration in range(n_iterations):
            state_indices = self._reporter.read_replica_thermodynamic_states(iteration)
            energy_matrix[state_indices, :, iteration] = energy_matrix_replica[:, :, iteration]
            unsampled_energy_matrix[state_indices, :, iteration] = unsampled_energy_matrix_replica[:, :, iteration]
        logger.info("Done.")
        return energy_matrix, unsampled_energy_matrix 
    @staticmethod
    def get_timeseries(passed_timeseries):
        """
        Compute the timeseries of a simulation from the Replica Exchange simulation. This is the sum of energies
        for each sample from the state it was drawn from.

        Parameters
        ----------
        passed_timeseries : ndarray of shape (K,L,N), indexed by k,l,n
            K is the total number of sampled states
            L is the total states we want MBAR to analyze
            N is the total number of samples
            The kth sample was drawn from state k at iteration n, the nth configuration of kth state is evaluated in
            thermodynamic state l

        Returns
        -------
        u_n : ndarray of float64, shape (N,)
            Timeseries to compute decorrelation and equilibration data from. Element n is the sum of the diagonal
            ``passed_timeseries[k, k, n]`` over k.
        """
        # Total negative log probability at every iteration: the trace over the first two axes, computed for all
        # iterations at once instead of in a Python loop. Accumulate in float64 like the original output buffer.
        return np.trace(passed_timeseries, axis1=0, axis2=1, dtype=np.float64)
    def _prepare_mbar_input_data(self, sampled_energy_matrix, unsampled_energy_matrix):
        """
        Convert the sampled and unsampled energy matrices into MBAR ready data.

        Parameters
        ----------
        sampled_energy_matrix : ndarray of shape (K,K,N)
            Energies of the samples drawn from each sampled state k, evaluated at every sampled state.
        unsampled_energy_matrix : ndarray of shape (K,L,N)
            Energies of the same samples evaluated at the L unsampled states. When L > 0 the first unsampled
            state (fully interacting) is placed before the sampled states and the second (noninteracting) after.

        Returns
        -------
        mbar_ready_energy_matrix : ndarray of shape (K,K,N) or (K+2,K+2,N)
            The input matrix itself when L == 0, otherwise the augmented matrix.
        N_k : ndarray of int32, shape (K,) or (K+2,)
            Number of samples drawn from each state; the two added unsampled states get 0 samples.

        Raises
        ------
        ValueError
            If exactly one unsampled state is provided (two are needed to bracket the sampled states).
        """
        nstates, _, niterations = sampled_energy_matrix.shape
        _, nunsampled, _ = unsampled_energy_matrix.shape
        # Every sampled state contributes one (already subsampled) sample per iteration.
        N_k = np.full(nstates, niterations, dtype=np.int32)
        if nunsampled == 0:
            return sampled_energy_matrix, N_k
        if nunsampled == 1:
            raise ValueError("Expected 2 unsampled states (fully interacting and noninteracting) "
                             "but found only 1.")
        if nunsampled > 2:
            logger.warning("Found {} unsampled states; only the first two are used.".format(nunsampled))
        fully_interacting_u_ln = unsampled_energy_matrix[:, 0, :]
        noninteracting_u_ln = unsampled_energy_matrix[:, 1, :]
        # Augment u_kln with the two unsampled states as the first and last state
        new_energy_matrix = np.zeros([nstates + 2, nstates + 2, niterations], np.float64)
        new_energy_matrix[1:-1, 0, :] = fully_interacting_u_ln
        new_energy_matrix[1:-1, -1, :] = noninteracting_u_ln
        new_energy_matrix[1:-1, 1:-1, :] = sampled_energy_matrix
        # Unsampled states have no samples drawn from them
        N_k_new = np.zeros(nstates + 2, np.int32)
        N_k_new[1:-1] = N_k
        # Notify users
        logger.info("Found expanded cutoff states in the energies!")
        logger.info("Free energies will be reported relative to them instead!")
        return new_energy_matrix, N_k_new
    def _compute_free_energy(self):
        """
        Estimate free energies of all alchemical states and cache them.

        Creates the MBAR object first if needed, logs the matrices of dimensionless free energy differences and
        of their uncertainties, and stores them in ``self._computed_observables['free_energy']`` as a dict with
        keys ``'value'`` and ``'error'``.
        """
        # Create MBAR object if not provided
        if self._mbar is None:
            self._create_mbar_from_scratch()
        nstates = self.mbar.N_k.size
        # Get matrix of dimensionless free energy differences and uncertainty estimate.
        logger.info("Computing covariance matrix...")
        # pymbar 2 returns (Deltaf_ij, dDeltaf_ij); pymbar 3 may append Theta_ij. Call once and keep the first
        # two entries rather than recomputing the (expensive) covariance after an unpacking failure.
        results = self.mbar.getFreeEnergyDifferences()
        Deltaf_ij, dDeltaf_ij = results[0], results[1]

        def log_matrix(label, matrix):
            """Log an nstates x nstates matrix row by row with a fixed-width format."""
            logger.info(label)
            for i in range(nstates):
                logger.info("".join("{:8.3f}".format(matrix[i, j]) for j in range(nstates)))

        # Matrix of free energy differences
        log_matrix("Deltaf_ij:", Deltaf_ij)
        # Matrix of uncertainties in free energy difference (expectations standard
        # deviations of the estimator about the true free energy)
        log_matrix("dDeltaf_ij:", dDeltaf_ij)
        # Cache free energy differences and their uncertainties.
        self._computed_observables['free_energy'] = {'value': Deltaf_ij, 'error': dDeltaf_ij}
[docs]    def get_free_energy(self):
        """
        Compute the free energy and error in free energy from the MBAR object
        Output shape changes based on if there are unsampled states detected in the sampler
        Returns
        -------
        DeltaF_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
            Difference in free energy from each state relative to each other state
        dDeltaF_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
            Error in the difference in free energy from each state relative to each other state
        """
        if self._computed_observables['free_energy'] is None:
            self._compute_free_energy()
        free_energy_dict = self._computed_observables['free_energy']
        return free_energy_dict['value'], free_energy_dict['error'] 
    def _compute_enthalpy_and_entropy(self):
        """
        Compute enthalpy and entropy differences (with errors) from MBAR and cache both.

        Builds the MBAR object first if needed. The results are stored in ``self._computed_observables`` under
        ``'enthalpy'`` and ``'entropy'`` as dicts with keys ``'value'`` and ``'error'``.
        """
        if self._mbar is None:
            self._create_mbar_from_scratch()
        # computeEntropyAndEnthalpy returns (f, df, H, dH, S, dS); the free energy part is not needed here.
        _, _, h_value, h_error, s_value, s_error = self.mbar.computeEntropyAndEnthalpy()
        cache = self._computed_observables
        cache['enthalpy'] = {'value': h_value, 'error': h_error}
        cache['entropy'] = {'value': s_value, 'error': s_error}
[docs]    def get_enthalpy(self):
        """
        Compute the difference in enthalpy and error in that estimate from the MBAR object
        Output shape changes based on if there are unsampled states detected in the sampler
        Returns
        -------
        DeltaH_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
            Difference in enthalpy from each state relative to each other state
        dDeltaH_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
            Error in the difference in enthalpy from each state relative to each other state
        """
        if self._computed_observables['enthalpy'] is None:
            self._compute_enthalpy_and_entropy()
        enthalpy_dict = self._computed_observables['enthalpy']
        return enthalpy_dict['value'], enthalpy_dict['error'] 
[docs]    def get_entropy(self):
        """
        Compute the difference in entropy and error in that estimate from the MBAR object
        Output shape changes based on if there are unsampled states detected in the sampler
        Returns
        -------
        DeltaH_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
            Difference in enthalpy from each state relative to each other state
        dDeltaH_ij : ndarray of floats, shape (K,K) or (K+2, K+2)
            Error in the difference in enthalpy from each state relative to each other state
        """
        if self._computed_observables['entropy'] is None:
            self._compute_enthalpy_and_entropy()
        entropy_dict = self._computed_observables['entropy']
        return entropy_dict['value'], entropy_dict['error'] 
[docs]    def get_standard_state_correction(self):
        """
        Compute the standard state correction free energy associated with the Phase.
        This usually is just a stored variable, but it may need other calculations.
        Returns
        -------
        standard_state_correction : float
            Free energy contribution from the standard_state_correction
        """
        if self._computed_observables['standard_state_correction'] is None:
            ssc = self._reporter.read_dict('metadata')['standard_state_correction']
            self._computed_observables['standard_state_correction'] = ssc
        return self._computed_observables['standard_state_correction'] 
    def _get_equilibration_data_auto(self, input_data=None):
        """
        Automatically generate the equilibration data from best practices, part of the :func:`_create_mbar_from_scratch`
        routine.

        Parameters
        ----------
        input_data : np.ndarray-like, Optional, Default: None
            Optionally provide the energy matrix to look at, in the (K,L,N) layout accepted by
            :func:`get_timeseries`. If not provided, the sampled-state energies returned by
            :func:`get_states_energies` are used.

        Returns nothing, but sets ``self._equilibration_data`` to the result of ``get_equilibration_data``
        evaluated on the timeseries with its first sample dropped.
        """
        if input_data is None:
            input_data, _ = self.get_states_energies()
        u_n = self.get_timeseries(input_data)
        # Discard equilibration samples.
        # TODO: if we include u_n[0] (the energy right after minimization) in the equilibration detection,
        # TODO:         then number_equilibrated is 0. Find a better way than just discarding first frame.
        self._equilibration_data = get_equilibration_data(u_n[1:])
    def _create_mbar_from_scratch(self):
        """
        Build the MBAR estimator from the reporter's energies.

        Reads the state-sorted energies, runs the automatic equilibration detection on the sampled-state energies,
        drops the unequilibrated samples from both the sampled and unsampled matrices, subsamples both with the
        same interval ``g_t`` (so they keep matching sample indices), and hands the MBAR-ready data to
        ``self._create_mbar``.
        """
        sampled, unsampled = self.get_states_energies()
        self._get_equilibration_data_auto(input_data=sampled)
        n_equilibrated, inefficiency, _ = self._equilibration_data
        # Trim the equilibration period along the sample (last) axis, sampled matrix first.
        trimmed = [remove_unequilibrated_data(matrix, n_equilibrated, -1) for matrix in (sampled, unsampled)]
        # Subsampling depends only on g_t, so both matrices end up with the same sample indices.
        sampled, unsampled = [subsample_data_along_axis(matrix, inefficiency, -1) for matrix in trimmed]
        energies, samples_per_state = self._prepare_mbar_input_data(sampled, unsampled)
        self._create_mbar(energies, samples_per_state)
    def analyze_phase(self, cutoff=0.05):
        """
        Run the standard analysis of this phase and collect its main observables.

        Builds the MBAR object if needed and shows the mixing statistics, then reads the free energy, enthalpy,
        and standard state correction.

        Parameters
        ----------
        cutoff : float, Optional, Default: 0.05
            Passed through to :func:`show_mixing_statistics`.

        Returns
        -------
        data : dict
            ``'DeltaF'``, ``'dDeltaF'``, ``'DeltaH'``, ``'dDeltaH'`` between the two reference states, and
            ``'DeltaF_standard_state_correction'``.
        """
        if self._mbar is None:
            self._create_mbar_from_scratch()
        n_equilibrated, _, _ = self._equilibration_data
        self.show_mixing_statistics(cutoff=cutoff, number_equilibrated=n_equilibrated)
        # Accumulate free energy and enthalpy differences between the reference states
        free_energy, free_energy_error = self.get_free_energy()
        enthalpy, enthalpy_error = self.get_enthalpy()
        i, j = self.reference_states[0], self.reference_states[1]
        return {
            'DeltaF': free_energy[i, j],
            'dDeltaF': free_energy_error[i, j],
            'DeltaH': enthalpy[i, j],
            'dDeltaH': enthalpy_error[i, j],
            'DeltaF_standard_state_correction': self.get_standard_state_correction(),
        }
# https://choderalab.slack.com/files/levi.naden/F4G6L9X8S/quick_diagram.png
class MultiPhaseAnalyzer(object):
    """
    Multiple Phase Analyzer creator, not to be directly called itself, but instead called by adding or subtracting
    different implemented :class:`YankPhaseAnalyzer` or other :class:`MultiPhaseAnalyzer` objects. The individual
    Phases of the :class:`MultiPhaseAnalyzer` are only references to existing Phase objects, not copies. All
    :class:`YankPhaseAnalyzer` and :class:`MultiPhaseAnalyzer` classes support ``+`` and ``-`` operations.
    The observables of this phase are determined through inspection of all the passed in phases and only observables
    which are shared can be computed. For example:
        ``PhaseA`` has ``.get_free_energy`` and ``.get_entropy``
        ``PhaseB`` has ``.get_free_energy`` and ``.get_enthalpy``,
        ``PhaseAB = PhaseA + PhaseB`` will only have a ``.get_free_energy`` method
    Because each Phase may have a different number of states, the ``reference_states`` property of each phase
    determines which states from each phase to read the data from.
    For observables defined by two states, the i'th and j'th reference states are used:
        If we define ``PhaseAB = PhaseA - PhaseB``
        Then ``PhaseAB.get_free_energy()`` is roughly equivalent to doing the following:
            ``A_i, A_j = PhaseA.reference_states``
            ``B_i, B_j = PhaseB.reference_states``
            ``PhaseA.get_free_energy()[A_i, A_j] - PhaseB.get_free_energy()[B_i, B_j]``
        The above is not exact since get_free_energy returns an error estimate as well
    For observables defined by a single state, only the i'th reference state is used
        Given ``PhaseAB = PhaseA + PhaseB``, ``PhaseAB.get_temperature()`` is equivalent to:
            ``A_i = PhaseA.reference_states[0]``
            ``B_i = PhaseB.reference_states[0]``
            ``PhaseA.get_temperature()[A_i] + PhaseB.get_temperature()[B_i]``
    For observables defined entirely by the phase, no reference states are needed.
        Given ``PhaseAB = PhaseA + PhaseB``, ``PhaseAB.get_standard_state_correction()`` gives:
            ``PhaseA.get_standard_state_correction() + PhaseB.get_standard_state_correction()``
    This class is public to see its API.

    Parameters
    ----------
    phases : dict
        has keys "phases", "names", and "signs"

    Attributes
    ----------
    observables
    phases
    names
    signs
    """
    def __init__(self, phases):
        """
        Create the compound phase which is any combination of phases to generate a new MultiPhaseAnalyzer.

        Parameters
        ----------
        phases : dict
            Must have the keys ``'phases'``, ``'names'`` and ``'signs'``, each mapping to a list with one entry
            per phase (the phase object, its unique name, and its ``'+'``/``'-'`` sign respectively).

        Raises
        ------
        RuntimeError
            If the phases do not share any observable registered in ``_ObservablesRegistry``.
        """
        # Keep only the registered observables that every phase can compute
        observables = [observable for observable in _ObservablesRegistry.observables()
                       if all(observable in phase.observables for phase in phases['phases'])]
        if not observables:
            raise RuntimeError("There are no shared computable observable between the phases, combining them will do "
                               "nothing.")
        self._observables = tuple(observables)
        self._phases = phases['phases']
        self._names = phases['names']
        self._signs = phases['signs']
        # Set the methods shared between both objects
        self._bind_observable_methods()

    def _bind_observable_methods(self):
        """Attach a ``get_<observable>`` callable for every observable, bound to this specific instance."""
        for observable in self.observables:
            setattr(self, "get_" + observable, self._spool_function(observable))

    def _spool_function(self, observable):
        """
        Dynamic observable calculator layer.

        Must be in its own function to isolate the variable name space. If the lambda were created directly in a
        loop, every lambda would close over the same loop variable ``observable`` and the wrong property would be
        fetched.
        """
        return lambda: self._compute_observable(observable)

    @property
    def observables(self):
        """List of observables this :class:`MultiPhaseAnalyzer` can generate"""
        return self._observables

    @property
    def phases(self):
        """List of implemented :class:`YankPhaseAnalyzer`'s objects this :class:`MultiPhaseAnalyzer` is tied to"""
        return self._phases

    @property
    def names(self):
        """
        Unique list of string names identifying this phase. If this :class:`MultiPhaseAnalyzer` is combined with
        another, its possible that new names will be generated unique to that :class:`MultiPhaseAnalyzer`, but will
        still reference the same phase.
        When in doubt, use :func:`MultiPhaseAnalyzer.phases` to get the actual phase objects.
        """
        return self._names

    @property
    def signs(self):
        """
        List of signs (``'+'`` or ``'-'``), one per phase, that the :class:`MultiPhaseAnalyzer` uses to decide
        whether each phase's observable is added to or subtracted from the combined result.
        """
        return self._signs

    def _combine_phases(self, other, operator='+'):
        """
        Function to combine the phases regardless of operator to reduce code duplication. Creates a new
        :class:`MultiPhaseAnalyzer` object based on the combined phases of the other. Accepts either a
        :class:`YankPhaseAnalyzer` or a :class:`MultiPhaseAnalyzer`.
        If the names have collision, they are re-named with an extra digit at the end.

        Parameters
        ----------
        other : :class:`MultiPhaseAnalyzer` or :class:`YankPhaseAnalyzer`
        operator : sign of the operator connecting the two objects

        Returns
        -------
        output : :class:`MultiPhaseAnalyzer`
            New :class:`MultiPhaseAnalyzer` where the phases are the combined list of the individual phases from each
            component. Because the memory pointers to the individual phases are the same, changing any
            single :class:`YankPhaseAnalyzer`'s
            reference_state objects updates all :class:`MultiPhaseAnalyzer` objects they are tied to

        Raises
        ------
        TypeError
            If ``other`` is neither a :class:`MultiPhaseAnalyzer` nor a :class:`YankPhaseAnalyzer`.
        """
        # Start from copies of this object's lists so that self is left untouched
        phases = list(self.phases)
        names = list(self.names)
        signs = list(self.signs)
        if isinstance(other, MultiPhaseAnalyzer):
            final_new_names = []
            for name in other.names:
                other_names = [n for n in other.names if n != name]
                # Also avoid the names generated earlier in this loop so renamed phases cannot collide
                final_new_names.append(generate_phase_name(name, other_names + names + final_new_names))
            names.extend(final_new_names)
            # Exactly one sign per incoming phase: subtracting flips the other's signs, adding keeps them
            for new_sign in other.signs:
                if (operator == '-' and new_sign == '+') or (operator == '+' and new_sign == '-'):
                    signs.append('-')
                else:
                    signs.append('+')
            phases.extend(other.phases)
        elif isinstance(other, YankPhaseAnalyzer):
            names.append(generate_phase_name(other.name, names))
            if (operator == '-' and other._sign == '+') or (operator == '+' and other._sign == '-'):
                signs.append('-')
            else:
                signs.append('+')
            other._sign = '+'  # Recast to positive if negated
            phases.append(other)
        else:
            baseerr = "cannot {} 'MultiPhaseAnalyzer' and '{}' objects"
            if operator == '+':
                err = baseerr.format('add', type(other))
            else:
                err = baseerr.format('subtract', type(other))
            raise TypeError(err)
        phase_pass = {'phases': phases, 'signs': signs, 'names': names}
        return MultiPhaseAnalyzer(phase_pass)

    def __add__(self, other):
        return self._combine_phases(other, operator='+')

    def __sub__(self, other):
        return self._combine_phases(other, operator='-')

    def __neg__(self):
        """
        Return a SHALLOW copy of self with negated signs so that the phase objects all still point to the same
        objects
        """
        new_signs = ['-' if sign == '+' else '+' for sign in self._signs]
        # return a *shallow* copy of self with the signs reversed
        output = copy.copy(self)
        output._signs = new_signs
        # The copied ``get_<observable>`` attributes are closures over ``self`` and would keep computing with the
        # original signs; rebind them to the copy so the negation actually takes effect.
        output._bind_observable_methods()
        return output

    def __str__(self):
        """Simplified string output"""
        header = "MultiPhaseAnalyzer<{}>"
        output_string = ""
        for phase_name, sign in zip(self.names, self.signs):
            if output_string == "" and sign == '-':
                output_string += '{}{} '.format(sign, phase_name)
            elif output_string == "":
                output_string += '{} '.format(phase_name)
            else:
                output_string += '{} {} '.format(sign, phase_name)
        return header.format(output_string)

    def __repr__(self):
        """Generate a detailed representation of the MultiPhase"""
        header = "MultiPhaseAnalyzer <\n{}>"
        output_string = ""
        for phase, phase_name, sign in zip(self.phases, self.names, self.signs):
            if output_string == "" and sign == '-':
                output_string += '{}{} ({})\n'.format(sign, phase_name, phase)
            elif output_string == "":
                output_string += '{} ({})\n'.format(phase_name, phase)
            else:
                output_string += '    {} {} ({})\n'.format(sign, phase_name, phase)
        return header.format(output_string)

    def _compute_observable(self, observable_name):
        """
        Helper function to compute arbitrary observable in both phases

        Parameters
        ----------
        observable_name : str
            Name of the observable as its defined in the ObservablesRegistry

        Returns
        -------
        observable_value
            The observable as its combined between all the phases; a ``(value, error)`` tuple for observables
            registered with an error.

        Raises
        ------
        RuntimeError
            If the observable is not registered under any of the phase/single-state/two-state categories.
        """
        def prepare_phase_observable(single_phase):
            """Helper function to cast the observable in terms of observable's registry"""
            observable = getattr(single_phase, "get_" + observable_name)()
            if isinstance(single_phase, MultiPhaseAnalyzer):
                if observable_name in _ObservablesRegistry.observables_with_error():
                    observable_payload = {}
                    observable_payload['value'], observable_payload['error'] = observable
                else:
                    observable_payload = observable
            else:
                raise_registry_error = False
                if observable_name in _ObservablesRegistry.observables_with_error():
                    observable_payload = {}
                    if observable_name in _ObservablesRegistry.observables_defined_by_phase():
                        observable_payload['value'], observable_payload['error'] = observable
                    elif observable_name in _ObservablesRegistry.observables_defined_by_single_state():
                        observable_payload['value'] = observable[0][single_phase.reference_states[0]]
                        observable_payload['error'] = observable[1][single_phase.reference_states[0]]
                    elif observable_name in _ObservablesRegistry.observables_defined_by_two_states():
                        observable_payload['value'] = observable[0][single_phase.reference_states[0],
                                                                    single_phase.reference_states[1]]
                        observable_payload['error'] = observable[1][single_phase.reference_states[0],
                                                                    single_phase.reference_states[1]]
                    else:
                        raise_registry_error = True
                else:  # No error
                    if observable_name in _ObservablesRegistry.observables_defined_by_phase():
                        observable_payload = observable
                    elif observable_name in _ObservablesRegistry.observables_defined_by_single_state():
                        observable_payload = observable[single_phase.reference_states[0]]
                    elif observable_name in _ObservablesRegistry.observables_defined_by_two_states():
                        observable_payload = observable[single_phase.reference_states[0],
                                                        single_phase.reference_states[1]]
                    else:
                        raise_registry_error = True
                if raise_registry_error:
                    raise RuntimeError("You have requested an observable that is improperly registered in the "
                                       "ObservablesRegistry!")
            return observable_payload

        def modify_final_output(passed_output, payload, sign):
            """Add or subtract one phase's payload into the running total, combining errors per the registry."""
            if observable_name in _ObservablesRegistry.observables_with_error():
                if sign == '+':
                    passed_output['value'] += payload['value']
                else:
                    passed_output['value'] -= payload['value']
                if observable_name in _ObservablesRegistry.observables_with_error_adding_linear():
                    passed_output['error'] += payload['error']
                elif observable_name in _ObservablesRegistry.observables_with_error_adding_quadrature():
                    passed_output['error'] = (passed_output['error']**2 + payload['error']**2)**0.5
            else:
                if sign == '+':
                    passed_output += payload
                else:
                    passed_output -= payload
            return passed_output

        if observable_name in _ObservablesRegistry.observables_with_error():
            final_output = {'value': 0, 'error': 0}
        else:
            final_output = 0
        for phase, phase_sign in zip(self.phases, self.signs):
            phase_observable = prepare_phase_observable(phase)
            final_output = modify_final_output(final_output, phase_observable, phase_sign)
        if observable_name in _ObservablesRegistry.observables_with_error():
            # Cast output to tuple
            final_output = (final_output['value'], final_output['error'])
        return final_output
def analyze_directory(source_directory):
    """
    Analyze contents of store files to compute free energy differences.

    This function is needed to preserve the old auto-analysis style of YANK. What it exactly does can be refined when
    more analyzers and simulations are made available. For now this function exposes the API.

    The directory must contain an ``analysis.yaml`` file listing ``[phase_name, sign]`` pairs; each phase is
    read from ``<phase_name>.nc`` in the same directory. Results are reported through the module logger.

    Parameters
    ----------
    source_directory : string
       The location of the simulation storage files.

    Raises
    ------
    RuntimeError
        If ``analysis.yaml`` is missing from ``source_directory`` or does not list any phase.
    """
    analysis_script_path = os.path.join(source_directory, 'analysis.yaml')
    if not os.path.isfile(analysis_script_path):
        err_msg = 'Cannot find analysis.yaml script in {}'.format(source_directory)
        logger.error(err_msg)
        raise RuntimeError(err_msg)
    with open(analysis_script_path, 'r') as f:
        # The script only holds plain lists of names and signs: safe_load is sufficient, avoids arbitrary object
        # construction, and (unlike yaml.load without a Loader) still works with PyYAML >= 6.
        analysis = yaml.safe_load(f)
    if not analysis:
        # Without any phase there is nothing to analyze (and no kT to report energies with).
        err_msg = 'analysis.yaml script in {} does not list any phase'.format(source_directory)
        logger.error(err_msg)
        raise RuntimeError(err_msg)
    phase_names = [phase_name for phase_name, sign in analysis]
    data = dict()
    for phase_name, sign in analysis:
        phase_path = os.path.join(source_directory, phase_name + '.nc')
        phase = get_analyzer(phase_path)
        data[phase_name] = phase.analyze_phase()
        kT = phase.kT
    # Compute free energy and enthalpy
    DeltaF = 0.0
    dDeltaF = 0.0
    DeltaH = 0.0
    dDeltaH = 0.0
    for phase_name, sign in analysis:
        DeltaF -= sign * (data[phase_name]['DeltaF'] + data[phase_name]['DeltaF_standard_state_correction'])
        dDeltaF += data[phase_name]['dDeltaF']**2
        DeltaH -= sign * (data[phase_name]['DeltaH'] + data[phase_name]['DeltaF_standard_state_correction'])
        dDeltaH += data[phase_name]['dDeltaH']**2
    # Uncertainties of independent phases add in quadrature
    dDeltaF = np.sqrt(dDeltaF)
    dDeltaH = np.sqrt(dDeltaH)
    # Attempt to guess type of calculation
    calculation_type = ''
    for phase in phase_names:
        if 'complex' in phase:
            calculation_type = ' of binding'
        elif 'solvent1' in phase:
            calculation_type = ' of solvation'
    # Print energies
    logger.info("")
    logger.info("Free energy{}: {:16.3f} +- {:.3f} kT ({:16.3f} +- {:.3f} kcal/mol)".format(
        calculation_type, DeltaF, dDeltaF, DeltaF * kT / units.kilocalories_per_mole,
        dDeltaF * kT / units.kilocalories_per_mole))
    logger.info("")
    for phase in phase_names:
        logger.info("DeltaG {:<25} : {:16.3f} +- {:.3f} kT".format(phase, data[phase]['DeltaF'],
                                                                   data[phase]['dDeltaF']))
        if data[phase]['DeltaF_standard_state_correction'] != 0.0:
            logger.info("DeltaG {:<25} : {:25.3f} kT".format('restraint',
                                                             data[phase]['DeltaF_standard_state_correction']))
    logger.info("")
    logger.info("Enthalpy{}: {:16.3f} +- {:.3f} kT ({:16.3f} +- {:.3f} kcal/mol)".format(
        calculation_type, DeltaH, dDeltaH, DeltaH * kT / units.kilocalories_per_mole,
        dDeltaH * kT / units.kilocalories_per_mole))
# ==========================================
# HELPER FUNCTIONS FOR TRAJECTORY EXTRACTION
# ==========================================
# ==============================================================================
# Extract trajectory from NetCDF4 file
# ==============================================================================