"""
Provides comparison functionality for Phase unit tests.
"""
import pytest
from schrodinger import structure
from schrodinger.infra import phase
# Default absolute tolerance for floating-point comparisons in the helpers
# of this module (passed to pytest.approx as ``abs``).
FLOAT_TOL = 0.0001
def compare_base_features(bf1, bf2):
    """
    Asserts that two PhpBaseFeature objects are equivalent.

    Compares the feature ID, feature type, projected type, projected-only
    flag, number of atoms, and the atom index at each position.

    :param bf1: First PhpBaseFeature.
    :param bf2: Second PhpBaseFeature.
    :raises AssertionError: On the first property that differs; the message
        names the property and both values.
    """
    # pytest only rewrites asserts in test modules/conftest (or modules
    # registered via pytest.register_assert_rewrite), so give every assertion
    # an explicit message, as the other helpers in this module do.
    for getter in ("getFeatureID", "getFeatureType", "getProjectedType",
                   "getProjectedOnly", "getNumberOfAtoms"):
        value1 = getattr(bf1, getter)()
        value2 = getattr(bf2, getter)()
        mesg = f"{getter}() values differ: {value1} != {value2}"
        assert value1 == value2, mesg
    for i in range(bf1.getNumberOfAtoms()):
        index1 = bf1.getAtomIndex(i)
        index2 = bf2.getAtomIndex(i)
        mesg = f"Atom index {i} differs: {index1} != {index2}"
        assert index1 == index2, mesg
def compare_hypo_add_cts(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that two PhpHypoAdaptor objects carry matching additional CTs.

    The additional-CT counts must be equal. Each pair of additional CTs,
    wrapped as structure.Structure, must then have matching coordinates.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute coordinate tolerance.
    :raises AssertionError: If the counts or any coordinates differ.
    """
    count = hypo1.getAddCtCount()
    other_count = hypo2.getAddCtCount()
    assert count == other_count, (
        f"Numbers of additional CTs differ: {count} != {other_count}")
    # Nothing more to check, and no need to fetch the CT lists.
    if count == 0:
        return
    ct_list1 = hypo1.getAddCts()
    ct_list2 = hypo2.getAddCts()
    # Additional CTs are labelled from 1 in failure messages.
    for number in range(1, count + 1):
        struct1 = structure.Structure(ct_list1[number - 1])
        struct2 = structure.Structure(ct_list2[number - 1])
        compare_structure_coordinates(struct1, struct2, tol,
                                      f"Additional CT {number}")
def compare_hypo_cnst(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the distance/angle/dihedral constraints in two
    PhpHypoAdaptor objects agree.

    The hasCnst() flags must match. When constraints are present:
    - the constraint counts must be equal;
    - each pair of constraints must have identical site indices and site
      types;
    - their values and tolerances must agree to within tol.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for constraint values and tolerances.
    :raises AssertionError: On the first difference found.
    """
    has_cnst1 = hypo1.hasCnst()
    has_cnst2 = hypo2.hasCnst()
    mesg = f"hasCnst() values differ: {has_cnst1} != {has_cnst2}"
    assert has_cnst1 == has_cnst2, mesg
    # As with the other optional hypothesis data (masks, QSAR, radii, rules,
    # tolerances, volumes), there is nothing further to compare when neither
    # hypothesis has constraints, so don't query getCnst() in that case.
    if not has_cnst1:
        return
    constraints1 = hypo1.getCnst().getConstraints()
    constraints2 = hypo2.getCnst().getConstraints()
    n1 = len(constraints1)
    n2 = len(constraints2)
    assert n1 == n2, f"Numbers of constraints differ: {n1} != {n2}"
    for cnst1, cnst2 in zip(constraints1, constraints2):
        indices1 = cnst1.getSiteIndices()
        indices2 = cnst2.getSiteIndices()
        mesg = f"Constraint site indices differ: {indices1} != {indices2}"
        assert indices1 == indices2, mesg
        types1 = cnst1.getSiteTypes()
        types2 = cnst2.getSiteTypes()
        mesg = f"Constraint site types differ: {types1} != {types2}"
        assert types1 == types2, mesg
        value1 = cnst1.getValue()
        value2 = cnst2.getValue()
        mesg = f"Constraint values differ: {value1} != {value2}"
        assert value1 == pytest.approx(value2, abs=tol), mesg
        tol1 = cnst1.getTol()
        tol2 = cnst2.getTol()
        mesg = f"Constraint tolerances differ: {tol1} != {tol2}"
        assert tol1 == pytest.approx(tol2, abs=tol), mesg
def compare_hypo_masks(hypo1, hypo2):
    """
    Asserts that the site masks in two PhpHypoAdaptor objects agree.

    The hasMask() flags must match. When both hypotheses have masks, the
    number of sites and the site mask vectors must also match.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :raises AssertionError: If any of the above differ.
    """
    # Mismatches are asserted rather than returned as message strings. This
    # way a caller that ignores the return value (as compare_hypos does)
    # still fails on differing masks, consistent with the other
    # compare_hypo_* helpers.
    has_mask1 = hypo1.hasMask()
    has_mask2 = hypo2.hasMask()
    mesg = f"hasMask() values differ: {has_mask1} != {has_mask2}"
    assert has_mask1 == has_mask2, mesg
    if not has_mask1:
        return
    mask1 = hypo1.getMask()
    mask2 = hypo2.getMask()
    n1 = mask1.numSites()
    n2 = mask2.numSites()
    assert n1 == n2, f"Numbers of mask values differ: {n1} != {n2}"
    v1 = mask1.getSiteMaskVector()
    v2 = mask2.getSiteMaskVector()
    assert v1 == v2, f"Mask vectors differ: {v1} != {v2}"
def compare_hypo_qsar(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the QSAR models of two PhpHypoAdaptor objects share the
    same basic characteristics.

    The hasQsar() flags must match. When QSAR models are present, their
    styles and factor counts must be equal. Their grid spacings, and the
    regression SD and R^2 of every factor, must agree to within tol.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for floating-point model statistics.
    :raises AssertionError: On the first difference found.
    """
    present1 = hypo1.hasQsar()
    present2 = hypo2.hasQsar()
    mesg = f"hasQsar() values differ: {present1} != {present2}"
    assert present1 == present2, mesg
    if not present1:
        return
    model1 = hypo1.getQsar()
    model2 = hypo2.getQsar()
    # The model accessors are PhpProject methods; this instance is only used
    # to query the two models.
    proj = phase.PhpProject()

    kind1 = proj.getQsarModelStyle(model1)
    kind2 = proj.getQsarModelStyle(model2)
    assert kind1 == kind2, f"QSAR model styles differ: {kind1} != {kind2}"

    grid1 = proj.getQsarModelGridSpacing(model1)
    grid2 = proj.getQsarModelGridSpacing(model2)
    mesg = f"QSAR model grid spacing differs: {grid1} != {grid2}"
    assert grid1 == pytest.approx(grid2, abs=tol), mesg

    nfac1 = proj.getQsarModelFactorCount(model1)
    nfac2 = proj.getQsarModelFactorCount(model2)
    mesg = f"QSAR model factor counts differ: {nfac1} != {nfac2}"
    assert nfac1 == nfac2, mesg

    # Per-factor regression statistics, checked in this order for each
    # factor; factors are queried as 1..N.
    regression_stats = (("SD", proj.getQsarModelRegSD),
                        ("R^2", proj.getQsarModelRegRsqr))
    for factor in range(1, nfac1 + 1):
        for label, getter in regression_stats:
            stat1 = getter(model1, factor)
            stat2 = getter(model2, factor)
            mesg = f"QSAR {label}({factor}) values differ: {stat1} != {stat2}"
            assert stat1 == pytest.approx(stat2, abs=tol), mesg
def compare_hypo_rad(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the feature radii in two PhpHypoAdaptor objects agree.

    The hasRad() flags must match. When radii are present, the radius
    stored for each of the keys A, D, H, N, P, R, X, Y and Z must agree to
    within tol.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for radii.
    :raises AssertionError: On the first difference found.
    """
    flag1 = hypo1.hasRad()
    flag2 = hypo2.hasRad()
    assert flag1 == flag2, f"hasRad() values differ: {flag1} != {flag2}"
    if flag1:
        radii1 = hypo1.getRad()
        radii2 = hypo2.getRad()
        # Single-character keys passed to getDataValue(); presumably the
        # Phase feature type codes.
        for feature in "ADHNPRXYZ":
            value1 = radii1.getDataValue(feature)
            value2 = radii2.getDataValue(feature)
            mesg = f"Feature radii differ: {value1} != {value2}"
            assert value1 == pytest.approx(value2, abs=tol), mesg
def compare_hypo_ref(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the reference CTs of two PhpHypoAdaptor objects have
    matching coordinates.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute coordinate tolerance.
    :raises AssertionError: If the reference CTs differ, as determined by
        compare_structure_coordinates.
    """
    ref_st1, ref_st2 = (structure.Structure(hypo.getRefCt())
                        for hypo in (hypo1, hypo2))
    compare_structure_coordinates(ref_st1, ref_st2, tol=tol,
                                  descr="Reference CT")
def compare_hypo_rules(hypo1, hypo2):
    """
    Asserts that the feature matching rules of two PhpHypoAdaptor objects
    agree.

    The hasRules() flags must match. When rules are present, the number of
    sites and the permitted and prohibited feature lists must be equal.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :raises AssertionError: On the first difference found.
    """
    flag_a = hypo1.hasRules()
    flag_b = hypo2.hasRules()
    mesg = f"hasRules() values differ: {flag_a} != {flag_b}"
    assert flag_a == flag_b, mesg
    if flag_a:
        rules_a = hypo1.getRules()
        rules_b = hypo2.getRules()
        count_a = rules_a.numSites()
        count_b = rules_b.numSites()
        assert count_a == count_b, (
            f"Numbers of feature matching rules differ: {count_a} != {count_b}")
        allowed_a = rules_a.getPermittedFeatures()
        allowed_b = rules_b.getPermittedFeatures()
        mesg = f"Permitted features differ: {allowed_a} != {allowed_b}"
        assert allowed_a == allowed_b, mesg
        banned_a = rules_a.getProhibitedFeatures()
        banned_b = rules_b.getProhibitedFeatures()
        mesg = f"Prohibited features differ: {banned_a} != {banned_b}"
        assert banned_a == banned_b, mesg
def compare_hypo_sites(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the sites in two PhpHypoAdaptor objects agree.

    The number of sites must match. Each pair of sites is then compared
    with compare_sites using the given tolerance.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for floating-point site properties.
    :raises AssertionError: On the first difference found.
    """
    sites1 = hypo1.getHypoSites()
    sites2 = hypo2.getHypoSites()
    len1 = len(sites1)
    len2 = len(sites2)
    assert len1 == len2, f"Numbers of sites differ: {len1} != {len2}"
    for site1, site2 in zip(sites1, sites2):
        # Forward tol so a caller-supplied tolerance is honored; otherwise
        # compare_sites would fall back to its FLOAT_TOL default.
        compare_sites(site1, site2, tol=tol)
def compare_hypo_tol(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the positional tolerances in two PhpHypoAdaptor objects
    agree.

    The hasTol() flags must match. When tolerances are present, the number
    of sites must be equal and each site's getDeltaXYZ() value must agree to
    within tol.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance used to compare the positional tolerances.
    :raises AssertionError: On the first difference found.
    """
    present1 = hypo1.hasTol()
    present2 = hypo2.hasTol()
    mesg = f"hasTol() values differ: {present1} != {present2}"
    assert present1 == present2, mesg
    if not present1:
        return
    pos_tol1 = hypo1.getTol()
    pos_tol2 = hypo2.getTol()
    nsites1 = pos_tol1.numSites()
    nsites2 = pos_tol2.numSites()
    assert nsites1 == nsites2, (
        f"Numbers of tolerances differ: {nsites1} != {nsites2}")
    for site in range(nsites1):
        delta1 = pos_tol1.getDeltaXYZ(site)
        delta2 = pos_tol2.getDeltaXYZ(site)
        mesg = f"Positional tolerances differ: {delta1} != {delta2}"
        assert delta1 == pytest.approx(delta2, abs=tol), mesg
def compare_hypo_vol(hypo1, hypo2, tol=FLOAT_TOL, included=False):
    """
    Asserts that the included or excluded volumes of two PhpHypoAdaptor
    objects agree.

    The has*() flags must match. When volumes are present, the sphere counts
    must be equal, and each sphere's centre and radius must agree to within
    tol.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for sphere coordinates and radii.
    :param included: If True, compare included volumes (hasIvol/getIvol);
        otherwise compare excluded volumes (hasXvol/getXvol).
    :raises AssertionError: On the first difference found.
    """
    vtype = "Included" if included else "Excluded"
    # Accessor names differ only in this suffix.
    suffix = "Ivol" if included else "Xvol"
    present1 = getattr(hypo1, "has" + suffix)()
    present2 = getattr(hypo2, "has" + suffix)()
    mesg = f"has{suffix}() values differ: {present1} != {present2}"
    assert present1 == present2, mesg
    if not present1:
        return
    vol1 = getattr(hypo1, "get" + suffix)()
    vol2 = getattr(hypo2, "get" + suffix)()
    count1 = vol1.numSpheres()
    count2 = vol2.numSpheres()
    assert count1 == count2, (
        f"{vtype} volume counts differ: {count1} != {count2}")
    for index in range(count1):
        ball1 = vol1.getSphere(index)
        ball2 = vol2.getSphere(index)
        centre1 = (ball1.x, ball1.y, ball1.z)
        centre2 = (ball2.x, ball2.y, ball2.z)
        mesg = f"{vtype} volume coordinates differ: {centre1} != {centre2}"
        for coord1, coord2 in zip(centre1, centre2):
            assert coord1 == pytest.approx(coord2, abs=tol), mesg
        radius1 = ball1.r
        radius2 = ball2.r
        mesg = f"{vtype} volume radii differ: {radius1} != {radius2}"
        assert radius1 == pytest.approx(radius2, abs=tol), mesg
def compare_hypos(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that two PhpHypoAdaptor objects are equivalent.

    Checks the hypothesis IDs and feature definitions, then delegates to the
    compare_hypo_* helpers for:
    - sites and additional CTs;
    - constraints and included volumes;
    - masks and QSAR models;
    - feature radii and the reference CT;
    - feature matching rules, positional tolerances and excluded volumes.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for floating-point comparisons, forwarded
        to every helper that accepts one.
    :raises AssertionError: On the first difference found.
    """
    hypo1_id = hypo1.getHypoID()
    hypo2_id = hypo2.getHypoID()
    mesg = f"Hypothesis IDs differ: {hypo1_id} != {hypo2_id}"
    assert hypo1_id == hypo2_id, mesg
    assert hypo1.getFd() == hypo2.getFd(), "Feature definitions differ"
    compare_hypo_sites(hypo1, hypo2, tol)
    compare_hypo_add_cts(hypo1, hypo2, tol)
    compare_hypo_cnst(hypo1, hypo2, tol)
    compare_hypo_vol(hypo1, hypo2, tol, included=True)
    # compare_hypo_masks may report a mismatch by returning a message rather
    # than asserting; treat a non-empty message as a failure.
    mask_mesg = compare_hypo_masks(hypo1, hypo2)
    assert not mask_mesg, mask_mesg
    compare_hypo_qsar(hypo1, hypo2, tol)
    compare_hypo_rad(hypo1, hypo2, tol)
    compare_hypo_ref(hypo1, hypo2, tol)
    compare_hypo_rules(hypo1, hypo2)
    compare_hypo_tol(hypo1, hypo2, tol)
    compare_hypo_vol(hypo1, hypo2, tol, included=False)
def compare_sites(site1, site2, tol=FLOAT_TOL, compare_site_status=True):
    """
    Asserts that two PhpSite objects describe the same site.

    These properties must match exactly:
    - site number and site type;
    - fragment type and mask value;
    - permitted and prohibited types;
    - projected-only flag and, optionally, site status.

    These must agree to within the absolute tolerance tol:
    - radius and tolerance;
    - base coordinates and projected-point coordinates.

    :param site1: First PhpSite.
    :param site2: Second PhpSite.
    :param tol: Absolute tolerance for floating-point properties.
    :param compare_site_status: If False, skip comparing getSiteStatus()
        (surface accessibility), which is not relevant to every application.
    :raises AssertionError: On the first property that differs.
    """
    num_a = site1.getSiteNumber()
    num_b = site2.getSiteNumber()
    assert num_a == num_b, f"Site numbers differ: {num_a} != {num_b}"
    kind_a = site1.getSiteType()
    kind_b = site2.getSiteType()
    assert kind_a == kind_b, f"Site types differ: \"{kind_a}\" != \"{kind_b}\""

    # Every remaining failure message is prefixed with the first site's name.
    label = site1.getSiteName(0)

    frag_a, frag_b = site1.getFragType(), site2.getFragType()
    text = f"{label} fragment types differ: \"{frag_a}\" != \"{frag_b}\""
    assert frag_a == frag_b, text

    mask_a, mask_b = site1.getMaskValue(), site2.getMaskValue()
    text = f"{label} mask values differ: {mask_a} != {mask_b}"
    assert mask_a == mask_b, text

    perm_a, perm_b = site1.getPermitted(), site2.getPermitted()
    text = f"{label} permitted types differ: \"{perm_a}\" != \"{perm_b}\""
    assert perm_a == perm_b, text

    proh_a, proh_b = site1.getProhibited(), site2.getProhibited()
    text = f"{label} prohibited types differ: \"{proh_a}\" != \"{proh_b}\""
    assert proh_a == proh_b, text

    only_a, only_b = site1.getProjectedOnly(), site2.getProjectedOnly()
    text = f"{label} projected only values differ {only_a} != {only_b}"
    assert only_a == only_b, text

    rad_a, rad_b = site1.getRad(), site2.getRad()
    text = f"{label} radii differ: {rad_a} != {rad_b}"
    assert rad_a == pytest.approx(rad_b, abs=tol), text

    if compare_site_status:
        status_a, status_b = site1.getSiteStatus(), site2.getSiteStatus()
        text = (f"{label} surface accessibilities differ: "
                f"{status_a} != {status_b}")
        assert status_a == status_b, text

    tol_a, tol_b = site1.getTol(), site2.getTol()
    text = f"{label} tolerances differ: {tol_a} != {tol_b}"
    assert tol_a == pytest.approx(tol_b, abs=tol), text

    base_a, base_b = site1.getCoordinates(), site2.getCoordinates()
    for coord_a, coord_b in zip(base_a, base_b):
        text = f"{label} base coordinates differ: {base_a} != {base_b}"
        assert coord_a == pytest.approx(coord_b, abs=tol), text

    proj_a, proj_b = site1.getProjCoords(), site2.getProjCoords()
    count_a, count_b = len(proj_a), len(proj_b)
    text = f"{label} projected point counts differ: {count_a} != {count_b}"
    assert count_a == count_b, text
    for point_a, point_b in zip(proj_a, proj_b):
        for coord_a, coord_b in zip(point_a, point_b):
            text = (f"{label} projected coordinates differ: "
                    f"{point_a} != {point_b}")
            assert coord_a == pytest.approx(coord_b, abs=tol), text
def compare_structure_coordinates(st1, st2, tol=FLOAT_TOL, descr=""):
    """
    Asserts that every atom of st1 has the same coordinates in st2.

    :param st1: First structure.Structure; all of its atom indices are
        compared.
    :param st2: Second structure.Structure.
    :param tol: Absolute tolerance per coordinate component.
    :param descr: Description of the structure used to prefix failure
        messages, e.g. "Conformer 7", "Reference CT" or "Additional CT 3".
    :raises AssertionError: If the atom totals or any coordinates differ.
    """
    compare_structure_coordinates_subset(st1,
                                         st2,
                                         st1.getAtomIndices(),
                                         tol=tol,
                                         descr=descr)
def compare_structure_coordinates_subset(st1,
                                         st2,
                                         atom_numbers,
                                         tol=FLOAT_TOL,
                                         descr=""):
    """
    Asserts that the coordinates of selected atoms agree between two
    structure.Structure objects.

    :param st1: First structure.
    :param st2: Second structure.
    :param atom_numbers: Atom numbers to compare. Numbering starts at 1.
    :param tol: Absolute tolerance applied to each coordinate component.
    :param descr: Optional description of the structure (e.g. "Conformer 7")
        used as a prefix in failure messages.
    :raises AssertionError: If the atom totals differ, or if any selected
        atom's coordinates differ by more than tol.
    """
    # Separate a non-empty description from the message text with a single
    # space. An empty description adds no prefix rather than a stray leading
    # space.
    if descr and not descr.endswith(" "):
        descr += " "
    n1 = st1.atom_total
    n2 = st2.atom_total
    assert n1 == n2, f"{descr}atom totals differ: {n1} != {n2}"
    for i in atom_numbers:
        xyz1 = st1.atom[i].xyz
        xyz2 = st2.atom[i].xyz
        # The message only depends on the atom, so build it once per atom.
        mesg = f"{descr}atom {i} coordinates differ: {xyz1} != {xyz2}"
        for value1, value2 in zip(xyz1, xyz2):
            assert value1 == pytest.approx(value2, abs=tol), mesg