Source code for mdtraj.testing.testing

##############################################################################
# MDTraj: A Python Library for Loading, Saving, and Manipulating
#         Molecular Dynamics Trajectories.
# Copyright 2012-2013 Stanford University and the Authors
#
# Authors: Robert McGibbon
# Contributors: Kyle A Beauchamp
#
# MDTraj is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 2.1
# of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with MDTraj. If not, see <http://www.gnu.org/licenses/>.
##############################################################################


##############################################################################
# imports
##############################################################################

from __future__ import print_function, division
import ast
import os
import sys
import functools
import numpy as np
from numpy.testing import (assert_allclose, assert_almost_equal,
  assert_approx_equal, assert_array_almost_equal, assert_array_almost_equal_nulp,
  assert_array_equal, assert_array_less, assert_array_max_ulp, assert_equal,
  assert_raises, assert_string_equal, assert_warns)
from numpy.testing.decorators import skipif, slow
from nose.tools import ok_, eq_, raises
from nose import SkipTest
from pkg_resources import resource_filename

# py2/3 compatibility
from mdtraj.utils.six import iteritems, integer_types, PY2

# if the system doesn't have scipy, we'd like
# this package to still work:
# we'll just redefine isspmatrix as a function that always returns
# false
try:
    from scipy.sparse import isspmatrix
except ImportError:
    isspmatrix = lambda x: False

try:
    # need special logic to check for equality of pandas DataFrames.
    # but this is only relevant if the user has pandas installed
    import pandas as pd
except ImportError:
    pass

# Public names exported by ``from <this module> import *``: the numpy.testing
# assertion helpers, the numpy ``skipif``/``slow`` decorators and the nose
# helpers imported above, together with the utilities defined further down
# in this file (``get_fn``, ``eq``, ``assert_dict_equal``,
# ``assert_sparse_matrix_equal`` and ``expected_failure``).
__all__ = ['assert_allclose', 'assert_almost_equal', 'assert_approx_equal',
           'assert_array_almost_equal', 'assert_array_almost_equal_nulp',
           'assert_array_equal', 'assert_array_less', 'assert_array_max_ulp',
           'assert_equal', 'assert_raises',
           'assert_string_equal', 'assert_warns', 'get_fn', 'eq',
           'assert_dict_equal', 'assert_sparse_matrix_equal',
           'expected_failure', 'SkipTest', 'ok_', 'eq_', 'raises', 'skipif',
           'slow']

##############################################################################
# functions
##############################################################################


def get_fn(name):
    """Return the full path to a reference file shipped for testing.

    In the source distribution the reference files are kept in
    ``MDTraj/testing/reference``; on installation they are moved to somewhere
    in the user's python site-packages directory. The location is therefore
    looked up through the ``mdtraj`` package resources rather than being
    hard-coded.

    Parameters
    ----------
    name : str
        Name of the file to load (with respect to the reference/ folder).

    Returns
    -------
    fn : str
        Path of the requested reference file.

    Raises
    ------
    ValueError
        If the resolved path does not exist on disk.

    Examples
    --------
    >>> import mdtraj as md
    >>> t = md.load(get_fn('2EQQ.pdb'))
    >>> eq(t.n_frames, 20)    # this runs the assert, using the eq() func.
    """
    relative_path = os.path.join('testing', 'reference', name)
    fn = resource_filename('mdtraj', relative_path)

    # happy path first: the file is where the package says it should be
    if os.path.exists(fn):
        return fn

    raise ValueError('Sorry! %s does not exists. If you just '
                     'added it, you\'ll have to re install' % fn)
def eq(o1, o2, decimal=6, err_msg=''):
    """Convenience function for asserting that two objects are equal

    The comparison is dispatched on the type of ``o1``: dicts, floats, sparse
    matrices, numpy arrays (including record arrays), pandas DataFrames and
    ``ast`` nodes each get an appropriate handler. This makes it more useful
    than ``assert o1 == o2``, which does not work for numpy arrays (it yields
    an array of bools, not a single True or False).

    Parameters
    ----------
    o1 : object
        The first object
    o2 : object
        The second object
    decimal : int
        If the two objects are floats or arrays of floats, they'll be checked
        for equality up to this decimal place.
    err_msg : str
        Custom error message

    Returns
    -------
    passed : bool
        True if the objects compare equal. When they do not, an
        AssertionError is raised instead of returning.

    Raises
    ------
    AssertionError
        If the objects are of different types or do not compare equal.
    """
    if PY2 and isinstance(o1, integer_types) and isinstance(o2, integer_types):
        # Handled before the strict type check below so that two integers of
        # different integer types can still be compared by value.
        eq_(long(o1), long(o2))
        return True

    assert (type(o1) is type(o2)), 'o1 and o2 not the same type: %s %s' % (type(o1), type(o2))

    if isinstance(o1, dict):
        # compare against o2 (comparing o1 with itself can never fail)
        assert_dict_equal(o1, o2, decimal)
    elif isinstance(o1, float):
        np.testing.assert_almost_equal(o1, o2, decimal, err_msg=err_msg)
    elif isspmatrix(o1):
        # compare against o2 (comparing o1 with itself can never fail)
        assert_sparse_matrix_equal(o1, o2, decimal)
    elif isinstance(o1, np.ndarray):
        if o1.dtype.kind == 'f' or o2.dtype.kind == 'f':
            # compare floats for almost equality
            assert_array_almost_equal(o1, o2, decimal, err_msg=err_msg)
        elif o1.dtype.type == np.record:
            # if it's a record array, we need to compare each field
            assert o1.dtype.names == o2.dtype.names
            for name in o1.dtype.names:
                eq(o1[name], o2[name], decimal=decimal, err_msg=err_msg)
        else:
            # compare everything else (ints, bools) for absolute equality
            assert_array_equal(o1, o2, err_msg=err_msg)
    elif 'pandas' in sys.modules and isinstance(o1, pd.DataFrame):
        # pandas dataframes are basically like dictionaries of numpy arrays
        assert_dict_equal(o1, o2, decimal=decimal)
    elif isinstance(o1, ast.AST) and isinstance(o2, ast.AST):
        eq_(ast.dump(o1), ast.dump(o2))
    else:
        # there are probably other specialized types that would need a
        # special check; fall back to plain equality for everything else
        eq_(o1, o2, msg=err_msg)

    return True
def assert_dict_equal(t1, t2, decimal=6):
    """Assert two dicts are equal.

    This method should actually work for any dict of numpy arrays/objects

    Parameters
    ----------
    t1 : object
    t2 : object
    decimal : int
        Number of decimal places to check, for arrays inside the dicts

    Raises
    ------
    AssertionError
        If the keys differ (the key lists are compared in iteration order) or
        any pair of values does not compare equal.
    """
    # make sure the keys are the same
    eq_(list(t1.keys()), list(t2.keys()))

    # whether t1 is a DataFrame does not change per key, so decide it once
    is_dataframe = 'pandas' in sys.modules and isinstance(t1, pd.DataFrame)

    for key, val in iteritems(t1):
        # compare numpy arrays (and DataFrame columns) using numpy.testing
        if isinstance(val, np.ndarray) or is_dataframe:
            if val.dtype.kind == 'f':
                # compare floats for almost equality
                assert_array_almost_equal(val, t2[key], decimal)
            else:
                # compare everything else (ints, bools) for absolute equality
                assert_array_equal(val, t2[key])
        else:
            eq_(val, t2[key])


def assert_sparse_matrix_equal(m1, m2, decimal=6):
    """Assert two scipy.sparse matrices are equal.

    Parameters
    ----------
    m1 : sparse_matrix
    m2 : sparse_matrix
    decimal : int
        Number of decimal places to check.

    Raises
    ------
    AssertionError
        If either argument is not a sparse matrix, if the formats or shapes
        differ, or if the summed absolute difference is not zero to
        ``decimal`` places.
    """
    # both must be sparse matrices
    assert isspmatrix(m1)
    assert isspmatrix(m2)

    # make sure they have the same format
    eq_(m1.format, m2.format)

    # report a shape mismatch as an assertion failure instead of letting the
    # subtraction below fail with a different error
    eq_(m1.shape, m2.shape)

    # Sum the *absolute* differences: with a signed sum, positive and
    # negative discrepancies could cancel and hide a real mismatch.
    # even though its called assert_array_almost_equal, it will
    # work for scalars
    assert_array_almost_equal(abs(m1 - m2).sum(), 0, decimal=decimal)


# decorator to mark tests as expected failure
def expected_failure(test):
    """Mark ``test`` as expected to fail.

    The wrapped test is reported as skipped when it raises, and fails with an
    AssertionError when it unexpectedly completes without raising.
    """
    @functools.wraps(test)
    def inner(*args, **kwargs):
        try:
            test(*args, **kwargs)
        except Exception:
            # Only ordinary failures count as "expected"; KeyboardInterrupt
            # and SystemExit derive from BaseException and must propagate.
            raise SkipTest
        else:
            raise AssertionError('Failure expected')
    return inner


# decorator to skip tests
def skip(reason):
    """Return a decorator that skips the decorated test.

    Parameters
    ----------
    reason : str
        Explanation attached to the raised ``SkipTest``.
    """
    def wrap(test):
        @functools.wraps(test)
        def inner(*args, **kwargs):
            # the wrapped test is never executed
            raise SkipTest(reason)
        return inner
    return wrap