##############################################################################
# MDTraj: A Python Library for Loading, Saving, and Manipulating
#         Molecular Dynamics Trajectories.
# Copyright 2012-2013 Stanford University and the Authors
#
# Authors: Robert McGibbon
# Contributors: Kyle A Beauchamp
#
# MDTraj is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with MDTraj. If not, see <http://www.gnu.org/licenses/>.
##############################################################################
##############################################################################
# imports
##############################################################################
from __future__ import print_function, division
import ast
import os
import sys
import functools
import numpy as np
from numpy.testing import (assert_allclose, assert_almost_equal,
  assert_approx_equal, assert_array_almost_equal, assert_array_almost_equal_nulp,
  assert_array_equal, assert_array_less, assert_array_max_ulp, assert_equal,
  assert_raises, assert_string_equal, assert_warns)
from numpy.testing.decorators import skipif, slow
from nose.tools import ok_, eq_, raises
from nose import SkipTest
from pkg_resources import resource_filename
# py2/3 compatibility
from mdtraj.utils.six import iteritems, integer_types, PY2
# if the system doesn't have scipy, we'd like
# this package to still work:
# we'll just redefine isspmatrix as a function that always returns
# false
try:
    from scipy.sparse import isspmatrix
except ImportError:
    isspmatrix = lambda x: False
try:
    # need special logic to check for equality of pandas DataFrames.
    # but this is only relevant if the user has pandas installed
    import pandas as pd
except ImportError:
    pass
# Names exported by ``from <this module> import *``: the numpy.testing and
# nose helpers re-exported from the imports above, plus the helper functions
# defined further down in this file.
# NOTE(review): ``skip`` is defined in this file but is not listed here --
# confirm whether leaving it out of the public API is intentional.
__all__ = ['assert_allclose', 'assert_almost_equal', 'assert_approx_equal',
           'assert_array_almost_equal', 'assert_array_almost_equal_nulp',
           'assert_array_equal', 'assert_array_less', 'assert_array_max_ulp',
           'assert_equal', 'assert_raises',
           'assert_string_equal', 'assert_warns', 'get_fn', 'eq',
           'assert_dict_equal', 'assert_sparse_matrix_equal',
           'expected_failure', 'SkipTest', 'ok_', 'eq_', 'raises', 'skipif',
           'slow']
##############################################################################
# functions
##############################################################################
def get_fn(name):
    """Get the full path to one of the reference files shipped for testing

    In the source distribution, these files are in ``MDTraj/testing/reference``,
    but on installation, they're moved to somewhere in the user's python
    site-packages directory.

    Parameters
    ----------
    name : str
        Name of the file to load (with respect to the reference/ folder).

    Returns
    -------
    fn : str
        Path to the requested file, as resolved by ``resource_filename``
        inside the ``mdtraj`` package.

    Raises
    ------
    ValueError
        If nothing exists at the resolved path.

    Examples
    --------
    >>> import mdtraj as md
    >>> t = md.load(get_fn('2EQQ.pdb'))
    >>> eq(t.n_frames, 20)    # this runs the assert, using the eq() func.
    """
    fn = resource_filename('mdtraj', os.path.join('testing', 'reference', name))
    if not os.path.exists(fn):
        # reference files are looked up in the installed package, so a file
        # that was only just added to the source tree is not visible yet
        raise ValueError('Sorry! %s does not exist. If you just '
                         'added it, you\'ll have to reinstall' % fn)
    return fn
def eq(o1, o2, decimal=6, err_msg=''):
    """Convenience function for asserting that two objects are equal to one another

    If the objects are both arrays or sparse matrices, this method will
    dispatch to an appropriate handler, which makes it a little bit more
    useful than just calling ``assert o1 == o2`` (which won't work for numpy
    arrays -- it returns an array of bools, not a single True or False)

    Parameters
    ----------
    o1 : object
        The first object
    o2 : object
        The second object
    decimal : int
        If the two objects are floats or arrays of floats, they'll be checked
        for equality up to this decimal place.
    err_msg : str
        Custom error message

    Returns
    -------
    passed : bool
        True if the comparison passes. A failed comparison never returns; it
        raises AssertionError instead.

    Raises
    ------
    AssertionError
        If the two objects have different types, or if they do not compare
        equal.
    """
    if PY2 and isinstance(o1, integer_types) and isinstance(o2, integer_types):
        # Python 2 has distinct ``int`` and ``long`` types; promote both so
        # that the strict same-type check below does not reject e.g. 1 vs 1L.
        eq_(long(o1), long(o2), msg=err_msg)  # noqa: F821 (py2-only builtin)
        return True

    assert (type(o1) is type(o2)), 'o1 and o2 not the same type: %s %s' % (type(o1), type(o2))

    if isinstance(o1, dict):
        assert_dict_equal(o1, o2, decimal)
    elif isinstance(o1, float):
        assert_almost_equal(o1, o2, decimal, err_msg=err_msg)
    elif isspmatrix(o1):
        assert_sparse_matrix_equal(o1, o2, decimal)
    elif isinstance(o1, np.ndarray):
        if o1.dtype.kind == 'f' or o2.dtype.kind == 'f':
            # compare floats for almost equality
            assert_array_almost_equal(o1, o2, decimal, err_msg=err_msg)
        elif o1.dtype.type is np.record:
            # if it's a record array, we need to compare each field in turn
            assert o1.dtype.names == o2.dtype.names
            for name in o1.dtype.names:
                eq(o1[name], o2[name], decimal=decimal, err_msg=err_msg)
        else:
            # compare everything else (ints, bools) for absolute equality
            assert_array_equal(o1, o2, err_msg=err_msg)
    elif 'pandas' in sys.modules and isinstance(o1, pd.DataFrame):
        # pandas dataframes are basically like dictionaries of numpy arrays
        assert_dict_equal(o1, o2, decimal=decimal)
    elif isinstance(o1, ast.AST):
        # AST nodes are compared through their textual dump
        eq_(ast.dump(o1), ast.dump(o2), msg=err_msg)
    else:
        # probably these are other specialized types
        # that need a special check?
        eq_(o1, o2, msg=err_msg)

    return True
def assert_dict_equal(t1, t2, decimal=6):
    """Check that two dict-like objects carry the same keys and values.

    Values that are numpy arrays are compared with the numpy.testing
    routines; when ``t1`` is a pandas DataFrame every value is treated that
    way. All remaining values are compared with ``eq_``.

    Parameters
    ----------
    t1 : object
        First mapping.
    t2 : object
        Second mapping.
    decimal : int
        Number of decimal places to check, for float arrays inside the
        mappings.
    """
    # the key listings have to match exactly (including their order)
    keys_first = list(t1.keys())
    keys_second = list(t2.keys())
    eq_(keys_first, keys_second)

    for key, val in iteritems(t1):
        use_numpy = isinstance(val, np.ndarray) or (
            'pandas' in sys.modules and isinstance(t1, pd.DataFrame))

        if not use_numpy:
            # ordinary python objects: plain equality
            eq_(val, t2[key])
            continue

        if val.dtype.kind == 'f':
            # floating point data only needs to agree up to `decimal` places
            assert_array_almost_equal(val, t2[key], decimal)
        else:
            # ints, bools, ... have to match exactly
            assert_array_equal(val, t2[key])
def assert_sparse_matrix_equal(m1, m2, decimal=6):
    """Assert two scipy.sparse matrices are equal.

    The matrices must both be sparse, share the same storage format and
    shape, and agree element-wise up to ``decimal`` places.

    Parameters
    ----------
    m1 : sparse_matrix
    m2 : sparse_matrix
    decimal : int
        Number of decimal places to check.

    Raises
    ------
    AssertionError
        If either argument is not a sparse matrix, or if the formats,
        shapes or entries differ.
    """
    # both arguments have to be sparse matrices
    if not isspmatrix(m1):
        raise AssertionError('m1 is not a sparse matrix: %s' % type(m1))
    if not isspmatrix(m2):
        raise AssertionError('m2 is not a sparse matrix: %s' % type(m2))

    # make sure they have the same format
    if m1.format != m2.format:
        raise AssertionError('%r != %r' % (m1.format, m2.format))

    # subtracting matrices of different shapes would raise a non-assertion
    # error, so report the mismatch explicitly
    if m1.shape != m2.shape:
        raise AssertionError('%r != %r' % (m1.shape, m2.shape))

    # Compare the stored entries of the difference one by one. Summing the
    # signed difference instead would let errors of opposite sign cancel,
    # so that unequal matrices could pass.
    diff = (m1 - m2).tocoo().data
    if diff.size:
        assert_array_almost_equal(diff, np.zeros_like(diff), decimal=decimal)
# decorator to mark tests as expected failure
def expected_failure(test):
    """Decorator marking a test that is expected to fail.

    Parameters
    ----------
    test : callable
        The test function to wrap.

    Returns
    -------
    inner : callable
        Wrapper around ``test``. When ``test`` raises an exception the
        wrapper raises ``SkipTest``; when ``test`` finishes without raising,
        the wrapper raises ``AssertionError('Failure expected')``.
    """
    @functools.wraps(test)
    def inner(*args, **kwargs):
        try:
            test(*args, **kwargs)
        except Exception:
            # Only ordinary exceptions count as the expected failure.
            # KeyboardInterrupt / SystemExit are not Exception subclasses and
            # keep propagating, so aborting a run is not reported as a skip.
            raise SkipTest
        else:
            raise AssertionError('Failure expected')
    return inner
# decorator to skip tests
def skip(reason):
    """Decorator factory that unconditionally skips the decorated test.

    Parameters
    ----------
    reason : str
        Explanation of why the test is skipped; it is passed on to the
        ``SkipTest`` exception raised by the wrapper.

    Returns
    -------
    wrap : callable
        Decorator that replaces a test with a wrapper which raises
        ``SkipTest(reason)`` without ever calling the test.
    """
    def wrap(test):
        @functools.wraps(test)
        def inner(*args, **kwargs):
            raise SkipTest(reason)
        return inner
    return wrap