Source code for yank.multistate.sams

#!/usr/local/bin/env python

# ==============================================================================
# MODULE DOCSTRING
# ==============================================================================

"""
SamsSampler
===========

Self-adjusted mixture sampling (SAMS), also known as optimally-adjusted mixture sampling.

This implementation uses stochastic approximation to allow one or more replicas to sample the whole range of thermodynamic states
for rapid online computation of free energies.

COPYRIGHT

Written by John D. Chodera <john.chodera@choderalab.org> while at Memorial Sloan Kettering Cancer Center.

LICENSE

This code is licensed under the latest available version of the MIT License.

"""

import logging
import numpy as np
import openmmtools as mmtools
from scipy.special import logsumexp

from .. import mpi
from .multistatesampler import MultiStateSampler
from .multistatereporter import MultiStateReporter
from .multistateanalyzer import MultiStateSamplerAnalyzer

logger = logging.getLogger(__name__)

# ==============================================================================
# SELF-ADJUSTED MIXTURE SAMPLING (SAMS) SAMPLER
# ==============================================================================


class SAMSSampler(MultiStateSampler):
    """Self-adjusted mixture sampling (SAMS), also known as optimally-adjusted mixture sampling.

    This class provides a facility for self-adjusted mixture sampling simulations.
    One or more replicas use the method of expanded ensembles [1] to sample multiple
    thermodynamic states within each replica, with log weights for each thermodynamic
    state adapted on the fly [2] to achieve the desired target probabilities for each
    state.

    Attributes
    ----------
    log_target_probabilities : array-like
        log_target_probabilities[state_index] is the log target probability for state ``state_index``
    state_update_scheme : str
        Thermodynamic state sampling scheme. One of ['global-jump', 'local-jump', 'restricted-range-jump']
    locality : int
        Number of neighboring states on either side to consider for local update schemes
    update_stages : str
        Number of stages to use for update. One of ['one-stage', 'two-stage']
    weight_update_method : str
        Method to use for updating log weights in SAMS. One of ['optimal', 'rao-blackwellized']
    adapt_target_probabilities : bool
        If True, target probabilities will be adapted to achieve minimal thermodynamic
        length between terminal thermodynamic states.
    gamma0 : float, optional, default=1.0
        Initial weight adaptation rate.
    logZ_guess : array-like of shape [n_states] of floats, optional, default=None
        Initial guess for logZ for all states, if available.

    References
    ----------
    [1] Lyubartsev AP, Martsinovski AA, Shevkunov SV, and Vorontsov-Velyaminov PN. New approach to
        Monte Carlo calculation of the free energy: Method of expanded ensembles. JCP 96:1776, 1992
        http://dx.doi.org/10.1063/1.462133

    [2] Tan, Z. Optimally adjusted mixture sampling and locally weighted histogram analysis,
        Journal of Computational and Graphical Statistics 26:54, 2017.
        http://dx.doi.org/10.1080/10618600.2015.1113975

    Examples
    --------
    SAMS simulation of alanine dipeptide in implicit solvent at different temperatures.

    Create the system:

    >>> import math
    >>> import os
    >>> import tempfile
    >>> from simtk import unit
    >>> from openmmtools import testsystems, states, mcmc
    >>> testsystem = testsystems.AlanineDipeptideVacuum()

    Create thermodynamic states for parallel tempering with exponentially-spaced schedule:

    >>> n_replicas = 3  # Number of temperature replicas.
    >>> T_min = 298.0 * unit.kelvin  # Minimum temperature.
    >>> T_max = 600.0 * unit.kelvin  # Maximum temperature.
    >>> temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0)
    ...                 for i in range(n_replicas)]
    >>> thermodynamic_states = [states.ThermodynamicState(system=testsystem.system, temperature=T)
    ...                         for T in temperatures]

    Initialize simulation object with options. Run with a GHMC integrator:

    >>> move = mcmc.GHMCMove(timestep=2.0*unit.femtoseconds, n_steps=50)
    >>> simulation = SAMSSampler(mcmc_moves=move, number_of_iterations=2,
    ...                          state_update_scheme='global-jump', locality=5,
    ...                          update_stages='two-stage', flatness_threshold=0.2,
    ...                          weight_update_method='rao-blackwellized',
    ...                          adapt_target_probabilities=False)

    Create a single-replica SAMS simulation bound to a storage file and run:

    >>> storage_path = tempfile.NamedTemporaryFile(delete=False).name + '.nc'
    >>> reporter = MultiStateReporter(storage_path, checkpoint_interval=1)
    >>> simulation.create(thermodynamic_states=thermodynamic_states,
    ...                   sampler_states=[states.SamplerState(testsystem.positions)],
    ...                   storage=reporter)
    >>> simulation.run()  # This runs for a maximum of 2 iterations.
    >>> simulation.iteration
    2
    >>> simulation.run(n_iterations=1)
    >>> simulation.iteration
    2

    To resume a simulation from an existing storage file and extend it beyond the
    original number of iterations.

    >>> del simulation
    >>> simulation = SAMSSampler.from_storage(reporter)
    >>> simulation.extend(n_iterations=1)
    >>> simulation.iteration
    3

    You can extract several information from the NetCDF file using the Reporter class
    while the simulation is running. This reads the SamplerStates of every run iteration.

    >>> reporter = MultiStateReporter(storage=storage_path, open_mode='r', checkpoint_interval=1)
    >>> sampler_states = reporter.read_sampler_states(iteration=range(1, 4))
    >>> len(sampler_states)
    3
    >>> sampler_states[-1].positions.shape  # Alanine dipeptide has 22 atoms.
    (22, 3)

    Clean up.

    >>> os.remove(storage_path)

    See Also
    --------
    ReplicaExchangeSampler

    """

    _TITLE_TEMPLATE = ('Self-adjusted mixture sampling (SAMS) simulation using SAMSSampler '
                       'class of yank.multistate on {}')

    def __init__(self,
                 number_of_iterations=1,
                 log_target_probabilities=None,
                 state_update_scheme='global-jump',
                 locality=5,
                 update_stages='two-stage',
                 flatness_threshold=0.2,
                 weight_update_method='rao-blackwellized',
                 adapt_target_probabilities=False,
                 gamma0=1.0,
                 logZ_guess=None,
                 **kwargs):
        """Initialize a SAMS sampler.

        Parameters
        ----------
        log_target_probabilities : array-like or None
            ``log_target_probabilities[state_index]`` is the log target probability
            for thermodynamic state ``state_index``.
            When converged, each state should be sampled with the specified log probability.
            If None, uniform probabilities for all states will be assumed.
        state_update_scheme : str, optional, default='global-jump'
            Specifies the scheme used to sample new thermodynamic states given fixed sampler states.
            One of ['global-jump', 'local-jump', 'restricted-range-jump']
            ``global-jump`` will allow the sampler to access any thermodynamic state
            ``local-jump`` will propose a move to one of the local neighborhood states, and accept or reject.
            ``restricted-range-jump`` will compute the probabilities for each of the states in the local
            neighborhood, increasing jump probability
        locality : int, optional, default=5
            Number of neighboring states on either side to consider for local update schemes.
        update_stages : str, optional, default='two-stage'
            One of ['one-stage', 'two-stage']
            ``one-stage`` will use the asymptotically optimal scheme throughout the entire simulation
            (not recommended due to slow convergence)
            ``two-stage`` will use a heuristic first stage to achieve flat histograms before switching
            to the asymptotically optimal scheme
        flatness_threshold : float, optional, default=0.2
            Histogram relative flatness threshold to use for first stage of two-stage scheme.
        weight_update_method : str, optional, default='rao-blackwellized'
            Method to use for updating log weights in SAMS. One of ['optimal', 'rao-blackwellized']
            ``rao-blackwellized`` will update log free energy estimate for all states for which
            energies were computed
            ``optimal`` will use integral counts to update log free energy estimate of current state only
        adapt_target_probabilities : bool, optional, default=False
            If True, target probabilities will be adapted to achieve minimal thermodynamic length
            between terminal thermodynamic states. (EXPERIMENTAL)
        gamma0 : float, optional, default=1.0
            Initial weight adaptation rate.
        logZ_guess : array-like of shape [n_states] of floats, optional, default=None
            Initial guess for logZ for all states, if available.

        """
        # Initialize multi-state sampler
        super(SAMSSampler, self).__init__(number_of_iterations=number_of_iterations, **kwargs)
        # Options
        self.log_target_probabilities = log_target_probabilities
        self.state_update_scheme = state_update_scheme
        self.locality = locality
        self.update_stages = update_stages
        self.flatness_threshold = flatness_threshold
        self.weight_update_method = weight_update_method
        self.adapt_target_probabilities = adapt_target_probabilities
        self.gamma0 = gamma0
        self.logZ_guess = logZ_guess
        # Private variables
        # self._replica_neighbors[replica_index] is a list of states that form the
        # neighborhood of ``replica_index``
        self._replica_neighbors = None
        self._cached_state_histogram = None

    class _StoredProperty(MultiStateSampler._StoredProperty):

        @staticmethod
        def _state_update_scheme_validator(instance, scheme):
            supported_schemes = ['global-jump', 'local-jump', 'restricted-range-jump']
            supported_schemes = ['global-jump']  # TODO: Eliminate this after release
            if scheme not in supported_schemes:
                raise ValueError("Unknown update scheme '{}'. Supported values "
                                 "are {}.".format(scheme, supported_schemes))
            return scheme

        @staticmethod
        def _update_stages_validator(instance, scheme):
            supported_schemes = ['one-stage', 'two-stage']
            if scheme not in supported_schemes:
                raise ValueError("Unknown update scheme '{}'. Supported values "
                                 "are {}.".format(scheme, supported_schemes))
            return scheme

        @staticmethod
        def _weight_update_method_validator(instance, scheme):
            supported_schemes = ['optimal', 'rao-blackwellized']
            if scheme not in supported_schemes:
                raise ValueError("Unknown update scheme '{}'. Supported values "
                                 "are {}.".format(scheme, supported_schemes))
            return scheme

        @staticmethod
        def _adapt_target_probabilities_validator(instance, scheme):
            # Target-probability adaptation is experimental and currently disabled.
            supported_schemes = [False]
            if scheme not in supported_schemes:
                raise ValueError("Unknown update scheme '{}'. Supported values "
                                 "are {}.".format(scheme, supported_schemes))
            return scheme

    log_target_probabilities = _StoredProperty('log_target_probabilities', validate_function=None)
    state_update_scheme = _StoredProperty('state_update_scheme', validate_function=_StoredProperty._state_update_scheme_validator)
    locality = _StoredProperty('locality', validate_function=None)
    update_stages = _StoredProperty('update_stages', validate_function=_StoredProperty._update_stages_validator)
    flatness_threshold = _StoredProperty('flatness_threshold', validate_function=None)
    weight_update_method = _StoredProperty('weight_update_method', validate_function=_StoredProperty._weight_update_method_validator)
    adapt_target_probabilities = _StoredProperty('adapt_target_probabilities', validate_function=_StoredProperty._adapt_target_probabilities_validator)
    gamma0 = _StoredProperty('gamma0', validate_function=None)
    logZ_guess = _StoredProperty('logZ_guess', validate_function=None)

    def _initialize_stage(self):
        """Initialize the weight adaptation stage and reference iteration."""
        self._t0 = 0  # reference iteration to subtract
        if self.update_stages == 'one-stage':
            self._stage = 1  # start with asymptotically-optimal stage
        elif self.update_stages == 'two-stage':
            self._stage = 0  # start with rapid heuristic adaptation initial stage

    def _pre_write_create(self, thermodynamic_states: list, sampler_states: list, storage, **kwargs):
        """Initialize SAMS sampler.

        Parameters
        ----------
        thermodynamic_states : list of openmmtools.states.ThermodynamicState
            Thermodynamic states to simulate, where one replica is allocated per state.
            Each state must have a system with the same number of atoms.
        sampler_states : list of openmmtools.states.SamplerState
            One or more sets of initial sampler states.
            The number of replicas is determined by the number of sampler states provided,
            and does not need to match the number of thermodynamic states.
            Most commonly, a single sampler state is provided.
        storage : str or Reporter
            If str: path to the storage file, checkpoint options are default
            If Reporter: Instanced :class:`Reporter` class, checkpoint information is read from
            In the future this will be able to take a Storage class as well.
        initial_thermodynamic_states : None or list or array-like of int of length len(sampler_states), optional, default: None.
            Initial thermodynamic_state index for each sampler_state.
            If no initial distribution is chosen, ``sampler_states`` are distributed between the
            ``thermodynamic_states`` following these rules:

            * If ``len(thermodynamic_states) == len(sampler_states)``: 1-to-1 distribution
            * If ``len(thermodynamic_states) > len(sampler_states)``: First and last state distributed first,
              remaining ``sampler_states`` spaced evenly by index until ``sampler_states`` are depleted.
              If there is only one ``sampler_state``, then the only first ``thermodynamic_state`` will be chosen
            * If ``len(thermodynamic_states) < len(sampler_states)``, each ``thermodynamic_state`` receives an
              equal number of ``sampler_states`` until there are insufficient number of ``sampler_states`` remaining
              to give each ``thermodynamic_state`` an equal number. Then the rules from the previous point are followed.

        metadata : dict, optional
            Simulation metadata to be stored in the file.
        """
        # Initialize replica-exchange simulation.
        super()._pre_write_create(thermodynamic_states, sampler_states, storage=storage, **kwargs)

        if self.state_update_scheme == 'global-jump':
            self.locality = None  # override locality to be global
        if self.locality is not None:
            if self.locality < 1:
                raise Exception('locality must be >= 1')
            elif self.locality >= self.n_states:
                # A locality that spans all states degenerates to global sampling.
                self.locality = None

        # Record current weight update stage
        self._initialize_stage()

        # Update log target probabilities
        if self.log_target_probabilities is None:
            # Default to uniform target probabilities: log(1/n_states) for every state.
            self.log_target_probabilities = np.zeros([self.n_states], np.float64) - np.log(self.n_states)

        # Record initial logZ estimates
        self._logZ = np.zeros([self.n_states], np.float64)
        if self.logZ_guess is not None:
            if len(self.logZ_guess) != self.n_states:
                raise Exception('Initial logZ_guess (dim {}) must have same number of states as n_states ({})'.format(
                    len(self.logZ_guess), self.n_states))
            self._logZ = np.array(self.logZ_guess, np.float64)

        # Update log weights
        self._update_log_weights()

    def _restore_sampler_from_reporter(self, reporter):
        """Restore SAMS-specific state (histogram, logZ, stage, t0) from storage."""
        super()._restore_sampler_from_reporter(reporter)
        self._cached_state_histogram = self._compute_state_histogram(reporter=reporter)
        logger.debug('Restored state histogram: {}'.format(self._cached_state_histogram))
        data = reporter.read_online_analysis_data(self._iteration, 'logZ', 'stage', 't0')
        self._logZ = data['logZ']
        self._stage = int(data['stage'][0])
        self._t0 = int(data['t0'][0])

        # Compute log weights from log target probability and logZ estimate
        self._update_log_weights()

        # Determine t0
        self._update_stage()

    @mpi.on_single_node(rank=0, broadcast_result=False, sync_nodes=False)
    @mpi.delayed_termination
    def _report_iteration_items(self):
        """Write SAMS online-analysis data and update the cached state histogram."""
        super(SAMSSampler, self)._report_iteration_items()

        self._reporter.write_online_data_dynamic_and_static(self._iteration,
                                                            logZ=self._logZ,
                                                            stage=self._stage,
                                                            t0=self._t0)
        # Split into which states and how many samplers are in each state.
        # Trying to do histogram[replica_thermo_states] += 1 does not correctly handle
        # multiple replicas in the same state.
        states, counts = np.unique(self._replica_thermodynamic_states, return_counts=True)
        if self._cached_state_histogram is None:
            self._cached_state_histogram = np.zeros(self.n_states, dtype=int)
        self._cached_state_histogram[states] += counts

    @mpi.on_single_node(0, broadcast_result=True)
    def _mix_replicas(self):
        """Update thermodynamic states according to user-specified scheme."""
        logger.debug("Updating thermodynamic states using %s scheme..." % self.state_update_scheme)

        # Reset storage to keep track of swap attempts this iteration.
        self._n_accepted_matrix[:, :] = 0
        self._n_proposed_matrix[:, :] = 0

        # Perform swap attempts according to requested scheme.
        # TODO: We may be able to refactor this to simply have different update schemes compute neighborhoods differently.
        # TODO: Can we allow "plugin" addition of new update schemes that can be registered externally?
        with mmtools.utils.time_it('Mixing of replicas'):
            # Initialize statistics. This matrix is modified by the jump function and used
            # when updating the logZ estimates.
            replicas_log_P_k = np.zeros([self.n_replicas, self.n_states], np.float64)
            if self.state_update_scheme == 'global-jump':
                self._global_jump(replicas_log_P_k)
            elif self.state_update_scheme == 'local-jump':
                self._local_jump(replicas_log_P_k)
            elif self.state_update_scheme == 'restricted-range-jump':
                self._restricted_range_jump(replicas_log_P_k)
            else:
                raise Exception('Programming error: Unreachable code')

        # Determine fraction of swaps accepted this iteration.
        n_swaps_proposed = self._n_proposed_matrix.sum()
        n_swaps_accepted = self._n_accepted_matrix.sum()
        swap_fraction_accepted = 0.0
        if n_swaps_proposed > 0:
            swap_fraction_accepted = float(n_swaps_accepted) / n_swaps_proposed
        logger.debug("Accepted {}/{} attempted swaps ({:.1f}%)".format(n_swaps_accepted, n_swaps_proposed,
                                                                       swap_fraction_accepted * 100.0))

        # Update logZ estimates
        self._update_logZ_estimates(replicas_log_P_k)

        # Update log weights based on target probabilities
        self._update_log_weights()

    def _local_jump(self, replicas_log_P_k):
        """Metropolized local jump: propose a uniform move within the neighborhood and accept/reject."""
        n_replica, n_states, locality = self.n_replicas, self.n_states, self.locality
        for (replica_index, current_state_index) in enumerate(self._replica_thermodynamic_states):
            u_k = np.zeros([n_states], np.float64)
            log_P_k = np.zeros([n_states], np.float64)
            # Determine current neighborhood.
            # NOTE(review): called with no argument here but with an explicit state index
            # below for the proposed state — presumably defaults to the current state; confirm
            # against the inherited _neighborhood implementation.
            neighborhood = self._neighborhood()
            neighborhood_size = len(neighborhood)
            # Propose a move from the current neighborhood.
            proposed_state_index = np.random.choice(neighborhood, p=np.ones([neighborhood_size], np.float64) / float(neighborhood_size))
            # Determine neighborhood for proposed state.
            proposed_neighborhood = self._neighborhood(proposed_state_index)
            proposed_neighborhood_size = len(proposed_neighborhood)
            # Compute state log weights.
            log_Gamma_j_L = - float(proposed_neighborhood_size)  # log probability of proposing return
            log_Gamma_L_j = - float(neighborhood_size)  # log probability of proposing new state
            L = current_state_index
            # Compute potential for all states in neighborhood
            for j in neighborhood:
                u_k[j] = self._energy_thermodynamic_states[replica_index, j]
            # Compute log of probability of selecting each state in neighborhood
            for j in neighborhood:
                if j != L:
                    log_P_k[j] = log_Gamma_L_j + min(0.0, log_Gamma_j_L - log_Gamma_L_j +
                                                     (self.log_weights[j] - u_k[j]) - (self.log_weights[L] - u_k[L]))
            P_k = np.zeros([n_states], np.float64)
            P_k[neighborhood] = np.exp(log_P_k[neighborhood])
            # Compute probability to return to current state L
            P_k[L] = 0.0
            P_k[L] = 1.0 - P_k[neighborhood].sum()
            log_P_k[L] = np.log(P_k[L])
            # Update context.
            new_state_index = np.random.choice(neighborhood, p=P_k[neighborhood])
            self._replica_thermodynamic_states[replica_index] = new_state_index
            # Accumulate statistics
            replicas_log_P_k[replica_index, :] = log_P_k[:]
            self._n_proposed_matrix[current_state_index, neighborhood] += 1
            self._n_accepted_matrix[current_state_index, new_state_index] += 1

    def _global_jump(self, replicas_log_P_k):
        """
        Global jump scheme.
        This method is described after Eq. 3 in [2]
        """
        n_replica, n_states = self.n_replicas, self.n_states
        for replica_index, current_state_index in enumerate(self._replica_thermodynamic_states):
            neighborhood = self._neighborhood(current_state_index)

            # Compute unnormalized log probabilities for all thermodynamic states.
            # The energy row is loop-invariant, so fetch it once rather than per candidate state.
            u_k = self._energy_thermodynamic_states[replica_index, :]
            log_P_k = np.zeros([n_states], np.float64)
            for state_index in neighborhood:
                log_P_k[state_index] = - u_k[state_index] + self.log_weights[state_index]
            log_P_k -= logsumexp(log_P_k)

            # Update sampler Context to current thermodynamic state.
            P_k = np.exp(log_P_k[neighborhood])
            new_state_index = np.random.choice(neighborhood, p=P_k)
            self._replica_thermodynamic_states[replica_index] = new_state_index

            # Accumulate statistics.
            replicas_log_P_k[replica_index, :] = log_P_k[:]
            self._n_proposed_matrix[current_state_index, neighborhood] += 1
            self._n_accepted_matrix[current_state_index, new_state_index] += 1

    def _restricted_range_jump(self, replicas_log_P_k):
        """Restricted-range jump: sample within the neighborhood, then Metropolize the neighborhood change."""
        # TODO: This has a major bug in that we also need to compute energies in `proposed_neighborhood`.
        # I'm working on a way to make this work.
        n_replica, n_states, locality = self.n_replicas, self.n_states, self.locality
        logger.debug('Using restricted range jump with locality %s' % str(self.locality))
        for (replica_index, current_state_index) in enumerate(self._replica_thermodynamic_states):
            u_k = self._energy_thermodynamic_states[replica_index, :]
            log_P_k = np.zeros([n_states], np.float64)

            # Propose new state from current neighborhood.
            neighborhood = self._neighborhood(current_state_index)
            logger.debug('  Current state : %d' % current_state_index)
            logger.debug('  Neighborhood : %s' % str(neighborhood))
            logger.debug('  Relative u_k : %s' % str(u_k[neighborhood] - u_k[current_state_index]))
            log_P_k[neighborhood] = self.log_weights[neighborhood] - u_k[neighborhood]
            log_P_k[neighborhood] -= logsumexp(log_P_k[neighborhood])
            logger.debug('  Neighborhood log_P_k: %s' % str(log_P_k[neighborhood]))
            P_k = np.exp(log_P_k[neighborhood])
            logger.debug('  Neighborhood P_k : %s' % str(P_k))
            proposed_state_index = np.random.choice(neighborhood, p=P_k)
            logger.debug('  Proposed state : %d' % proposed_state_index)

            # Determine neighborhood of proposed state.
            proposed_neighborhood = self._neighborhood(proposed_state_index)
            logger.debug('  Proposed neighborhood : %s' % str(proposed_neighborhood))

            # Accept or reject.
            log_P_accept = logsumexp(self.log_weights[neighborhood] - u_k[neighborhood]) \
                - logsumexp(self.log_weights[proposed_neighborhood] - u_k[proposed_neighborhood])
            logger.debug('  log_P_accept : %f' % log_P_accept)
            logger.debug('  logsumexp(g[forward] - u[forward]) : %f' % logsumexp(self.log_weights[neighborhood] - u_k[neighborhood]))
            logger.debug('  logsumexp(g[reverse] - u[reverse]) : %f' % logsumexp(self.log_weights[proposed_neighborhood] - u_k[proposed_neighborhood]))
            new_state_index = current_state_index
            if (log_P_accept >= 0.0) or (np.random.rand() < np.exp(log_P_accept)):
                new_state_index = proposed_state_index
            logger.debug('  new_state_index : %d' % new_state_index)
            self._replica_thermodynamic_states[replica_index] = new_state_index

            # Accumulate statistics
            replicas_log_P_k[replica_index, :] = log_P_k[:]
            self._n_proposed_matrix[current_state_index, neighborhood] += 1
            self._n_accepted_matrix[current_state_index, new_state_index] += 1

    @property
    def _state_histogram(self):
        """
        Compute the histogram for the number of times each state has been visited.

        Returns
        -------
        N_k : array-like of shape [n_states] of int
            N_k[state_index] is the number of times a replica has visited state ``state_index``
        """
        if self._cached_state_histogram is None:
            self._cached_state_histogram = self._compute_state_histogram()
        return self._cached_state_histogram

    def _compute_state_histogram(self, reporter=None):
        """Compute state histogram from disk."""
        if reporter is None:
            reporter = self._reporter
        replica_thermodynamic_states = reporter.read_replica_thermodynamic_states()
        logger.debug('Read replica thermodynamic states: {}'.format(replica_thermodynamic_states))
        # Bin edges at half-integers so each integer state index gets its own bin.
        n_k, _ = np.histogram(replica_thermodynamic_states, bins=np.arange(-0.5, self.n_states + 0.5))
        return n_k

    def _update_stage(self):
        """
        Determine which adaptation stage we're in by checking histogram flatness.
        """
        # TODO: Make this a user option
        #flatness_criteria = 'minimum-visits' # DEBUG
        flatness_criteria = 'logZ-flatness' # DEBUG
        minimum_visits = 1
        N_k = self._state_histogram
        logger.debug('  state histogram counts ({} total): {}'.format(self._cached_state_histogram.sum(),
                                                                      self._cached_state_histogram))
        if (self.update_stages == 'two-stage') and (self._stage == 0):
            advance = False
            if N_k.sum() == 0:
                # No samples yet; don't do anything.
                return

            if flatness_criteria == 'minimum-visits':
                # Advance if every state has been visited at least once
                if np.all(N_k >= minimum_visits):
                    advance = True
            elif flatness_criteria == 'flatness-threshold':
                # Check histogram flatness
                empirical_pi_k = N_k[:] / N_k.sum()
                pi_k = np.exp(self.log_target_probabilities)
                relative_error_k = np.abs(pi_k - empirical_pi_k) / pi_k
                if np.all(relative_error_k < self.flatness_threshold):
                    advance = True
            elif flatness_criteria == 'logZ-flatness':
                # TODO: Advance to asymptotically optimal scheme when logZ update fractional
                # counts per state exceed threshold for all states.
                criteria = abs(self._logZ / self.gamma0) > self.flatness_threshold
                logger.debug('logZ-flatness criteria met (%d total): %s' % (np.sum(criteria), str(np.array(criteria, 'i1'))))
                if np.all(criteria):
                    advance = True
            else:
                raise ValueError("Unknown flatness_criteria %s" % flatness_criteria)

            if advance or ((self._t0 > 0) and (self._iteration > self._t0)):
                # Histograms are sufficiently flat; switch to asymptotically optimal scheme
                self._stage = 1  # asymptotically optimal
                # TODO: On resuming, we need to recompute or restore t0, or use some other way to compute it
                self._t0 = self._iteration - 1

    def _update_logZ_estimates(self, replicas_log_P_k):
        """
        Update the logZ estimates according to selected SAMS update method.

        References
        ----------
        [1] http://www.stat.rutgers.edu/home/ztan/Publication/SAMS_redo4.pdf
        """
        logger.debug('Updating logZ estimates...')

        # Store log weights used at the beginning of this iteration
        self._reporter.write_online_analysis_data(self._iteration, log_weights=self.log_weights)

        # Retrieve target probabilities
        log_pi_k = self.log_target_probabilities
        pi_k = np.exp(self.log_target_probabilities)

        # Update which stage we're in, checking histogram flatness
        self._update_stage()
        logger.debug('  stage: %s' % self._stage)

        # Update logZ estimates from all replicas
        for (replica_index, state_index) in enumerate(self._replica_thermodynamic_states):
            logger.debug(' Replica %d state %d' % (replica_index, state_index))

            # Compute attenuation factor gamma
            beta_factor = 0.8
            pi_star = pi_k.min()
            t = float(self._iteration)
            if self._stage == 0:  # initial stage
                gamma = self.gamma0 * min(pi_star, t**(-beta_factor))  # Eq. 15 of [1]
            elif self._stage == 1:
                gamma = self.gamma0 * min(pi_star, (t - self._t0 + self._t0**beta_factor)**(-1))  # Eq. 15 of [1]
            else:
                raise Exception('stage {} unknown'.format(self._stage))

            # Update online logZ estimate
            if self.weight_update_method == 'optimal':
                # Based on Eq. 9 of Ref. [1]
                logZ_update = gamma * np.exp(-log_pi_k[state_index])
                self._logZ[state_index] += logZ_update
            elif self.weight_update_method == 'rao-blackwellized':
                # Based on Eq. 12 of Ref [1]
                # TODO: This has to be the previous state index and log_P_k used before update; store neighborhood?
                # TODO: Can we use masked arrays for this purpose?
                log_P_k = replicas_log_P_k[replica_index, :]
                # Compact list of states defining neighborhood.
                neighborhood = np.where(self._neighborhoods[replica_index, :])[0]
                logZ_update = gamma * np.exp(log_P_k[neighborhood] - log_pi_k[neighborhood])
                self._logZ[neighborhood] += logZ_update
            else:
                raise Exception('Programming error: Unreachable code')

        # Subtract off logZ[0] to prevent logZ from growing without bound
        # once we reach the asymptotically optimal stage.
        if self._stage == 1:  # asymptotically optimal or one-stage
            self._logZ[:] -= self._logZ[0]

        # Format logZ
        msg = '  logZ: ['
        for i, val in enumerate(self._logZ):
            if i > 0:
                msg += ', '
            msg += '%6.1f' % val
        msg += ']'
        logger.debug(msg)

        # Store gamma
        self._reporter.write_online_analysis_data(self._iteration, gamma=gamma)

    def _update_log_weights(self):
        """
        Update the log weights based on current online logZ estimates.
        """
        # TODO: Add option to adapt target probabilities as well
        # TODO: If target probabilities are adapted, we need to store them as well
        self.log_weights = self.log_target_probabilities[:] - self._logZ[:]
[docs]class SAMSAnalyzer(MultiStateSamplerAnalyzer): """ The SAMSAnalyzer is the analyzer for a simulation generated from a SAMSSampler simulation. See Also -------- ReplicaExchangeAnalyzer PhaseAnalyzer """ pass
# ============================================================================== # MAIN AND TESTS # ============================================================================== if __name__ == "__main__": import doctest doctest.testmod()