"""
Contains custom assertions for use with unittest and Schrodinger data
structures.
copyright (c) Schrodinger, LLC. All rights reserved.
"""
import argparse
import numbers
import shlex
import unittest
import _pytest.assertion.util
from pytest import approx
from schrodinger.infra import mm
from schrodinger.structutils import smiles
from schrodinger.structutils.rmsd import ConformerRmsd
# Marker looked up by unittest: modules whose globals define ``__unittest``
# are treated as test-framework code, so their frames are left out of
# failure tracebacks and the report points at the calling test instead.
__unittest = True
"""Keeps stack trace from including this module."""
# Default absolute tolerance for floating point comparisons; used as the
# default ``tolerance`` of ``_CmpFloat`` and ``assert_properties_match``.
_DEFAULT_TOLERANCE = 0.005
def assertSameStructure(st1, st2, smiles_generator=None):
    """
    Assert that two structures have the same atom count and connectivity.

    The atom count check is delegated to ``assertSameNumberOfAtoms``. Both
    structures are then converted to SMILES with one generator, and the two
    strings must be identical.

    :type st1: `schrodinger.structure.Structure`
    :param st1: First structure
    :type st2: `schrodinger.structure.Structure`
    :param st2: Second structure
    :type smiles_generator: `schrodinger.structutils.smiles.SmilesGenerator`
    :param smiles_generator: Optional smiles generator with specified options.
        When omitted (or falsy), a generator using
        STEREO_FROM_ANNOTATION_AND_GEOM stereo handling is created.
    :raise AssertionError: number of atom/connectivity mismatch.
    """
    assertSameNumberOfAtoms(st1, st2)
    generator = smiles_generator or smiles.SmilesGenerator(
        stereo=smiles.STEREO_FROM_ANNOTATION_AND_GEOM)
    first, second = [generator.getSmiles(st) for st in (st1, st2)]
    if first == second:
        return
    raise AssertionError(f"SMILES mismatch: {first} != {second}")
def assertEqualShFiles(file1, file2):
    """
    Check that two files have the same Schrodinger command that are starting
    with "${SCHRODINGER}/run". The files to be compared should be the ones
    written by Appframework2 when a write() function is called.

    Each file is read and stripped of surrounding whitespace. Everything up to
    and including the first space is then ignored, because the path separator
    in "${SCHRODINGER}/run" is OS specific. The remaining argument strings must
    match exactly. A command with no arguments compares as an empty argument
    string.

    :type file1: str
    :param file1: File containing the first command
    :type file2: str
    :param file2: File containing the second command
    :rtype: None
    :raise AssertionError: if the commands (after the first token) differ
    """
    # Compare commands in 2 files
    with open(file1) as ref_cmd_fd:
        ref_cmd = ref_cmd_fd.read().strip()
    with open(file2) as cmd_fd:
        cmd = cmd_fd.read().strip()
    # Ignore "${SCHRODINGER}/run", as the path separator will be OS specific.
    # partition() (unlike index()) does not raise ValueError when a command
    # has no space at all; it simply yields an empty argument string.
    ref_args = ref_cmd.partition(' ')[2]
    args = cmd.partition(' ')[2]
    if ref_args != args:
        raise AssertionError(
            f"Commands did not match:\nFrom {file1}:\n{ref_cmd}\n"
            f"From {file2}:\n{cmd}")
def assertEqualCommandFiles(file1,
                            file2,
                            arg_parser,
                            skip_count=0,
                            ignore_options=None):
    """
    Assert that the commands stored in two files are equivalent.

    Each file is expected to hold exactly one line: the command to compare.
    The file contents are tokenized with shell quoting rules, and the leading
    ``skip_count`` tokens are dropped from each. The rest is handed to
    ``assertEqualCommandArgs``. Because an ArgumentParser does the comparison,
    this is best used on the arguments alone rather than the command name.

    :param file1: Path to command file1
    :type file1: str
    :param file2: Path to command file2
    :type file2: str
    :param arg_parser: Function that takes a list of command arguments and
        returns an argparse.Namespace, or a tuple (argparse.Namespace, list)
        if the arg_parser is an ArgumentParser.parse_known_args()
    :type arg_parser: func(list)
    :param skip_count: Number of items to skip at the beginning of each
        command, e.g. 1 to skip the program name. Default is 0.
    :type skip_count: int
    :param ignore_options: List of options to be ignored while comparing the
        commands. Default is not to ignore anything.
    :type ignore_options: list(options)
    """
    def tokenize(path):
        # Split the whole file the way a POSIX shell would.
        with open(path) as handle:
            return shlex.split(handle.read())

    tokens1, tokens2 = [tokenize(path) for path in (file1, file2)]
    assertEqualCommandArgs(tokens1[skip_count:], tokens2[skip_count:],
                           arg_parser, ignore_options)
def assertEqualCommandArgs(args1, args2, arg_parser, ignore_options=None):
    """
    Check if two commands passed as list of arguments are same.

    Both argument lists are parsed with ``arg_parser``. The resulting
    namespaces are compared attribute by attribute with ``assertEqualDicts``.

    :param args1: Argument list1
    :type args1: list
    :param args2: Argument list2
    :type args2: list
    :param arg_parser: Function that takes a list of command arguments and
        returns an argparse.Namespace, or a tuple (argparse.Namespace, list)
        if the arg_parser is an ArgumentParser.parse_known_args()
    :type arg_parser: func(list)
    :param ignore_options: List of options to be ignored while comparing the
        commands. Default is not to ignore anything.
    :type ignore_options: list(options)
    :raise ValueError: if arg_parser returns neither a Namespace nor a
        non-empty tuple whose first item is a Namespace
    :raise AssertionError: if the parsed options differ
    """
    def get_namespace(namespace_or_tuple):
        # isinstance (rather than an exact type check) so Namespace
        # subclasses, e.g. a custom ``namespace=`` object given to
        # parse_args(), are accepted too.
        if isinstance(namespace_or_tuple, argparse.Namespace):
            return namespace_or_tuple
        # Guard against an empty tuple, which would otherwise raise
        # IndexError instead of the documented ValueError.
        if (isinstance(namespace_or_tuple, tuple) and namespace_or_tuple and
                isinstance(namespace_or_tuple[0], argparse.Namespace)):
            return namespace_or_tuple[0]
        raise ValueError(
            f'Invalid return type by arg_parser:{type(namespace_or_tuple)}')

    namespace1 = get_namespace(arg_parser(args1))
    namespace2 = get_namespace(arg_parser(args2))
    assertEqualDicts(vars(namespace1), vars(namespace2), ignore_options)
def assertEqualDicts(dict1, dict2, ignore_keys=None, tolerance=None):
    """
    Compares two python dicts.

    Apart from ``ignore_keys``, the dicts must have the same set of keys and
    equal values for every compared key. When ``tolerance`` is truthy and
    either value is a non-integral real number, the pair is compared with
    ``pytest.approx`` using ``tolerance`` as the absolute tolerance.

    :param dict1: First dict
    :type dict1: dict
    :param dict2: Second dict
    :type dict2: dict
    :param ignore_keys: List of keys to be ignored while comparing the
        dicts. Default is not to ignore anything.
    :type ignore_keys: list(keys)
    :param tolerance: Tolerance for comparing fractional numeric values. Pass
        None to always compare with `==`.
    :type tolerance: float or None
    :raise AssertionError: if the key sets differ or any compared values are
        not equal
    """
    if ignore_keys is None:
        ignore_keys = []
    # Check for the length of the dicts only if there is nothing to ignore
    if not ignore_keys and len(dict1) != len(dict2):
        raise AssertionError(f'dict lengths are not equal:\n{dict1}\n{dict2}\n')

    # Failures are raised explicitly rather than via ``assert`` statements,
    # which are stripped under ``python -O`` and would let mismatches pass.
    for key, val1 in dict1.items():
        if key in ignore_keys:
            continue
        if key not in dict2:
            raise AssertionError(f'key {key} not present in both dicts')
        val2 = dict2[key]
        if tolerance and (_is_fractional(val1) or _is_fractional(val2)):
            matches = val1 == approx(val2, abs=tolerance)
        else:
            matches = val1 == val2
        if not matches:
            raise AssertionError(
                f'values differ for key {key!r}: {val1!r} != {val2!r}')

    # Shared values were compared above. Only make sure dict2 has no keys
    # that dict1 lacks.
    for key in dict2:
        if key not in ignore_keys and key not in dict1:
            raise AssertionError(f'key {key} not present in both dicts')
def _is_fractional(i):
    """
    Tell whether ``i`` is a real number that is not an integer type.

    Integral values (including ``bool``) are rejected first. Anything else
    counts as fractional only if it is a ``numbers.Real`` (e.g. ``float`` or
    ``fractions.Fraction``).
    """
    if isinstance(i, numbers.Integral):
        return False
    return isinstance(i, numbers.Real)
class StructureAssertionsTestCase(unittest.TestCase):
    """
    "Convenience" class to allow structural assertions to be called
    similarly to built in assertions.
    """

    def assertSameStructure(self, st1, st2, smiles_generator=None):
        """
        Assert that ``st1`` and ``st2`` have the same connectivity.

        Thin wrapper around the module-level ``assertSameStructure``
        function. ``smiles_generator`` is forwarded to it, so callers can
        choose the SMILES options just as with the function.

        :type st1: `schrodinger.structure.Structure`
        :param st1: First structure
        :type st2: `schrodinger.structure.Structure`
        :param st2: Second structure
        :type smiles_generator: `schrodinger.structutils.smiles.SmilesGenerator`
        :param smiles_generator: Optional smiles generator; None keeps the
            function's default generator.
        :raise AssertionError: number of atom/connectivity mismatch.
        """
        # Inside the method body this name resolves to the module-level
        # function, not to this method (class scope is not searched).
        assertSameStructure(st1, st2, smiles_generator)
class _CmpFloat:
    """
    Store a float with a custom tolerance for equality comparisons.

    Two values compare equal when their absolute difference is strictly less
    than the tolerance of the ``_CmpFloat`` whose ``__eq__`` runs. Comparing
    against a non-numeric object yields "not equal" rather than an error.
    """

    # Tolerance-based equality is not transitive and cannot be made
    # consistent with hashing, so instances stay unhashable.
    __hash__ = None

    def __init__(self, value, tolerance=_DEFAULT_TOLERANCE):
        self._value = value
        self._tolerance = tolerance

    def __eq__(self, that):
        # Unwrap another _CmpFloat; plain numbers are used as-is.
        other = that._value if isinstance(that, _CmpFloat) else that
        try:
            return abs(self._value - other) < self._tolerance
        except TypeError:
            # Non-numeric operand (e.g. a str or None). Defer to Python's
            # default handling, which reports "not equal", instead of
            # raising out of e.g. a dict comparison.
            return NotImplemented

    def __ne__(self, that):
        result = self.__eq__(that)
        return result if result is NotImplemented else not result

    # Backward-compatible alias for the previous, misspelled method name;
    # Python's ``!=`` operator only ever calls ``__ne__``.
    __neq__ = __ne__

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return repr(self._value)

    @classmethod
    def asCmp(cls, value, tolerance=_DEFAULT_TOLERANCE):
        """
        Give "value" a tolerance for equality comparison if it is a float.

        Only ``float`` instances (including subclasses) are wrapped; any other
        value is returned unchanged.
        """
        if isinstance(value, float):
            return cls(value, tolerance)
        return value
def _highlight_nothing(source, lexer='python'):
    """Identity "highlighter" for pytest versions whose diff helpers need one."""
    return source


def _comparable_properties(st, ignore, tolerance):
    """
    Return st's structure-level properties, plus its CT stereo status, as a
    dict. Float values are wrapped in ``_CmpFloat`` for tolerant comparison,
    and keys listed in ``ignore`` are dropped.
    """
    props = {k: _CmpFloat.asCmp(v, tolerance) for k, v in st.property.items()}
    # Stereo status is added before filtering so it can be ignored as well.
    props[mm.MMCT_STEREO_STATUS_PROP] = mm.mmct_ct_get_stereo_status(st)
    if ignore:
        props = {k: v for k, v in props.items() if k not in ignore}
    return props


def assert_properties_match(st1,
                            st2,
                            ignore=None,
                            tolerance=_DEFAULT_TOLERANCE,
                            msg=None):
    """
    Do the Structure level properties of st1 and st2 match?

    Skips properties listed in ignore. The CT stereo status is compared as
    well. Float properties are compared with an absolute ``tolerance``. Uses
    pytest-like formatting for the dictionary diff.

    Output like::

        custom_assertions.assert_properties_match(st1, st2)
        E   AssertionError: Omitting 2 identical items, use -vv to show
        E   Right contains 1 more item:
        E   {'i_user_test': 22}

    or::

        custom_assertions.assert_properties_match(st1, st2)
        E   AssertionError: Omitting 2 identical items, use -vv to show
        E   Differing items:
        E   {'i_user_test': 21} != {'i_user_test': 22}

    :param ignore: Property names to leave out of the comparison.
    :param tolerance: Absolute tolerance for float properties.
    :param msg: Optional text placed before the diff in the error message.
    :raise AssertionError: if the (non-ignored) properties differ
    """
    p1 = _comparable_properties(st1, ignore, tolerance)
    p2 = _comparable_properties(st2, ignore, tolerance)
    if p1 == p2:
        return
    try:
        explanation = _pytest.assertion.util._compare_eq_dict(p1, p2, True)
    except TypeError:
        # Newer pytest releases (8.x) insert a required ``highlighter``
        # callable before the verbosity argument. Passing True there fails as
        # soon as a diff line is highlighted, so retry with a no-op one.
        explanation = _pytest.assertion.util._compare_eq_dict(
            p1, p2, _highlight_nothing, True)
    if msg:
        explanation.insert(0, msg)
    raise AssertionError('\n'.join(explanation))