"""
Provides comparison functionality for Phase unit tests.
"""
import pytest
from schrodinger import structure
from schrodinger.infra import phase
# Default absolute tolerance (``abs=`` for pytest.approx) used by the
# floating-point comparisons in this module.
FLOAT_TOL = 1.0e-4
def compare_base_features(bf1, bf2):
    """
    Compares two PhpBaseFeature objects.

    The feature ID, feature type, projected type, projected-only flag,
    number of atoms and every atom index must be equal.

    :param bf1: First PhpBaseFeature.
    :param bf2: Second PhpBaseFeature.
    :raises AssertionError: On the first difference found, with a message
        naming the property that differs.
    """
    id1 = bf1.getFeatureID()
    id2 = bf2.getFeatureID()
    assert id1 == id2, f"Feature IDs differ: {id1} != {id2}"
    ftype1 = bf1.getFeatureType()
    ftype2 = bf2.getFeatureType()
    assert ftype1 == ftype2, f"Feature types differ: {ftype1} != {ftype2}"
    ptype1 = bf1.getProjectedType()
    ptype2 = bf2.getProjectedType()
    mesg = f"Projected types differ: {ptype1} != {ptype2}"
    assert ptype1 == ptype2, mesg
    ponly1 = bf1.getProjectedOnly()
    ponly2 = bf2.getProjectedOnly()
    mesg = f"Projected only values differ: {ponly1} != {ponly2}"
    assert ponly1 == ponly2, mesg
    n1 = bf1.getNumberOfAtoms()
    n2 = bf2.getNumberOfAtoms()
    assert n1 == n2, f"Numbers of atoms differ: {n1} != {n2}"
    for i in range(n1):
        atom1 = bf1.getAtomIndex(i)
        atom2 = bf2.getAtomIndex(i)
        mesg = f"Atom index {i} values differ: {atom1} != {atom2}"
        assert atom1 == atom2, mesg
def compare_hypo_add_cts(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the additional CTs stored in two PhpHypoAdaptor objects
    agree.

    The additional-CT counts must be equal. When there is at least one
    additional CT, each pair of corresponding CTs is wrapped in a
    structure.Structure and checked with compare_structure_coordinates,
    labelled "Additional CT <n>" with n counting from 1.
    """
    count1, count2 = hypo1.getAddCtCount(), hypo2.getAddCtCount()
    assert count1 == count2, (
        f"Numbers of additional CTs differ: {count1} != {count2}")
    if count1:
        cts1, cts2 = hypo1.getAddCts(), hypo2.getAddCts()
        for idx in range(count1):
            compare_structure_coordinates(structure.Structure(cts1[idx]),
                                          structure.Structure(cts2[idx]),
                                          tol, f"Additional CT {idx + 1}")
def compare_hypo_cnst(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Compares distance/angle/dihedral constraints in the provided
    PhpHypoAdaptor objects.

    The hasCnst() flags must agree. When constraints are present, the
    numbers of constraints must match, and each pair of corresponding
    constraints must have equal site indices and site types, and values
    and tolerances that agree to within ``tol``.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for constraint values and tolerances.
    :raises AssertionError: On the first difference found.
    """
    has_cnst1 = hypo1.hasCnst()
    has_cnst2 = hypo2.hasCnst()
    mesg = f"hasCnst() values differ: {has_cnst1} != {has_cnst2}"
    assert has_cnst1 == has_cnst2, mesg
    # Neither hypothesis has constraints, so there is nothing else to
    # compare. This matches the other compare_hypo_* helpers.
    if not has_cnst1:
        return
    constraints1 = hypo1.getCnst().getConstraints()
    constraints2 = hypo2.getCnst().getConstraints()
    n1 = len(constraints1)
    n2 = len(constraints2)
    assert n1 == n2, f"Numbers of constraints differ: {n1} != {n2}"
    for cnst1, cnst2 in zip(constraints1, constraints2):
        indices1 = cnst1.getSiteIndices()
        indices2 = cnst2.getSiteIndices()
        mesg = f"Constraint site indices differ: {indices1} != {indices2}"
        assert indices1 == indices2, mesg
        types1 = cnst1.getSiteTypes()
        types2 = cnst2.getSiteTypes()
        mesg = f"Constraint site types differ: {types1} != {types2}"
        assert types1 == types2, mesg
        value1 = cnst1.getValue()
        value2 = cnst2.getValue()
        mesg = f"Constraint values differ: {value1} != {value2}"
        assert value1 == pytest.approx(value2, abs=tol), mesg
        tol1 = cnst1.getTol()
        tol2 = cnst2.getTol()
        mesg = f"Constraint tolerances differ: {tol1} != {tol2}"
        assert tol1 == pytest.approx(tol2, abs=tol), mesg
def compare_hypo_masks(hypo1, hypo2):
    """
    Compares site masks in the provided PhpHypoAdaptor objects.

    The hasMask() flags must agree. When masks are present, the numbers of
    mask values and the site mask vectors must be equal.

    Mismatches are raised as AssertionError, like the other compare_*
    helpers in this module, so a caller cannot silently ignore them.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :return: An empty string when the masks match. The string return value
        is kept so callers that check the returned message keep working.
    :raises AssertionError: On the first difference found.
    """
    has_mask1 = hypo1.hasMask()
    has_mask2 = hypo2.hasMask()
    mesg = f"hasMask() values differ: {has_mask1} != {has_mask2}"
    assert has_mask1 == has_mask2, mesg
    if not has_mask1:
        return ""
    mask1 = hypo1.getMask()
    mask2 = hypo2.getMask()
    n1 = mask1.numSites()
    n2 = mask2.numSites()
    assert n1 == n2, f"Numbers of mask values differ: {n1} != {n2}"
    v1 = mask1.getSiteMaskVector()
    v2 = mask2.getSiteMaskVector()
    assert v1 == v2, f"Mask vectors differ: {v1} != {v2}"
    return ""
def compare_hypo_qsar(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the QSAR models attached to two PhpHypoAdaptor objects
    agree in their basic characteristics.

    The hasQsar() flags must match. When a QSAR model is present, the model
    style and factor count must be equal, and the grid spacing must agree
    to within ``tol``. For every factor number from 1 up to the factor
    count, the regression SD and R^2 values must also agree to within
    ``tol``. The model properties are queried through a phase.PhpProject
    instance.
    """
    present1, present2 = hypo1.hasQsar(), hypo2.hasQsar()
    assert present1 == present2, (
        f"hasQsar() values differ: {present1} != {present2}")
    if not present1:
        return

    model1, model2 = hypo1.getQsar(), hypo2.getQsar()
    proj = phase.PhpProject()

    style_a = proj.getQsarModelStyle(model1)
    style_b = proj.getQsarModelStyle(model2)
    assert style_a == style_b, (
        f"QSAR model styles differ: {style_a} != {style_b}")

    grid_a = proj.getQsarModelGridSpacing(model1)
    grid_b = proj.getQsarModelGridSpacing(model2)
    assert grid_a == pytest.approx(grid_b, abs=tol), (
        f"QSAR model grid spacing differs: {grid_a} != {grid_b}")

    nfact_a = proj.getQsarModelFactorCount(model1)
    nfact_b = proj.getQsarModelFactorCount(model2)
    assert nfact_a == nfact_b, (
        f"QSAR model factor counts differ: {nfact_a} != {nfact_b}")

    # Factor numbers are 1-based.
    for nfac in range(1, nfact_a + 1):
        sd_a = proj.getQsarModelRegSD(model1, nfac)
        sd_b = proj.getQsarModelRegSD(model2, nfac)
        assert sd_a == pytest.approx(sd_b, abs=tol), (
            f"QSAR SD({nfac}) values differ: {sd_a} != {sd_b}")
        r2_a = proj.getQsarModelRegRsqr(model1, nfac)
        r2_b = proj.getQsarModelRegRsqr(model2, nfac)
        assert r2_a == pytest.approx(r2_b, abs=tol), (
            f"QSAR R^2({nfac}) values differ: {r2_a} != {r2_b}")
def compare_hypo_rad(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Compares feature radii in the provided PhpHypoAdaptor objects.

    The hasRad() flags must agree. When radii are present, the value that
    getDataValue() returns for each one-letter code in "ADHNPRXYZ"
    (presumably the Phase feature types) must agree to within ``tol``.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for the radii.
    :raises AssertionError: On the first difference found. The message names
        the feature code whose radii differ.
    """
    has_rad1 = hypo1.hasRad()
    has_rad2 = hypo2.hasRad()
    mesg = f"hasRad() values differ: {has_rad1} != {has_rad2}"
    assert has_rad1 == has_rad2, mesg
    if not has_rad1:
        return
    rad1 = hypo1.getRad()
    rad2 = hypo2.getRad()
    for code in "ADHNPRXYZ":
        r1 = rad1.getDataValue(code)
        r2 = rad2.getDataValue(code)
        mesg = f"Feature radii differ for '{code}': {r1} != {r2}"
        assert r1 == pytest.approx(r2, abs=tol), mesg
def compare_hypo_ref(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Asserts that the reference CTs of two PhpHypoAdaptor objects match.

    Each reference CT is wrapped in a structure.Structure and the pair is
    passed to compare_structure_coordinates with the label "Reference CT".
    """
    ref1, ref2 = (structure.Structure(hypo.getRefCt())
                  for hypo in (hypo1, hypo2))
    compare_structure_coordinates(ref1, ref2, tol, "Reference CT")
def compare_hypo_rules(hypo1, hypo2):
    """
    Asserts that two PhpHypoAdaptor objects carry equivalent feature
    matching rules.

    The hasRules() flags must agree. When rules are present, the number of
    sites and both the permitted and prohibited feature collections must be
    equal.
    """
    flag_a, flag_b = hypo1.hasRules(), hypo2.hasRules()
    assert flag_a == flag_b, f"hasRules() values differ: {flag_a} != {flag_b}"
    if flag_a:
        rules_a, rules_b = hypo1.getRules(), hypo2.getRules()
        count_a, count_b = rules_a.numSites(), rules_b.numSites()
        assert count_a == count_b, (
            f"Numbers of feature matching rules differ: {count_a} != {count_b}")
        # The permitted and prohibited checks differ only in accessor and label.
        for label, getter in (("Permitted", "getPermittedFeatures"),
                              ("Prohibited", "getProhibitedFeatures")):
            feats_a = getattr(rules_a, getter)()
            feats_b = getattr(rules_b, getter)()
            assert feats_a == feats_b, (
                f"{label} features differ: {feats_a} != {feats_b}")
def compare_hypo_sites(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Compares the sites in the provided PhpHypoAdaptor objects.

    The numbers of sites must match. Corresponding sites are then compared
    pairwise with compare_sites, using the same ``tol``.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance forwarded to compare_sites.
    :raises AssertionError: On the first difference found.
    """
    sites1 = hypo1.getHypoSites()
    sites2 = hypo2.getHypoSites()
    len1 = len(sites1)
    len2 = len(sites2)
    assert len1 == len2, f"Numbers of sites differ: {len1} != {len2}"
    for site1, site2 in zip(sites1, sites2):
        # Forward tol so the caller's tolerance applies to site comparisons.
        compare_sites(site1, site2, tol=tol)
def compare_hypo_tol(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Compares positional tolerances in the provided PhpHypoAdaptor objects.

    The hasTol() flags must agree. When tolerances are present, the numbers
    of sites must match, and getDeltaXYZ(i) for each site index i must agree
    to within ``tol``.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for the positional tolerance values.
    :raises AssertionError: On the first difference found. The message
        includes the 0-based site index.
    """
    has_tol1 = hypo1.hasTol()
    has_tol2 = hypo2.hasTol()
    mesg = f"hasTol() values differ: {has_tol1} != {has_tol2}"
    assert has_tol1 == has_tol2, mesg
    if not has_tol1:
        return
    ptol1 = hypo1.getTol()
    ptol2 = hypo2.getTol()
    n1 = ptol1.numSites()
    n2 = ptol2.numSites()
    assert n1 == n2, f"Numbers of tolerances differ: {n1} != {n2}"
    for i in range(n1):
        value1 = ptol1.getDeltaXYZ(i)
        value2 = ptol2.getDeltaXYZ(i)
        mesg = (f"Positional tolerances differ at index {i}: "
                f"{value1} != {value2}")
        assert value1 == pytest.approx(value2, abs=tol), mesg
def compare_hypo_vol(hypo1, hypo2, tol=FLOAT_TOL, included=False):
    """
    Asserts that two PhpHypoAdaptor objects carry matching excluded volumes
    (the default) or included volumes.

    ``included`` selects the accessor pair: hasIvol()/getIvol() when true,
    hasXvol()/getXvol() otherwise. The presence flags must agree. When
    volumes are present, the sphere counts must be equal, and each sphere's
    x/y/z centre and radius r must agree to within ``tol``.
    """
    vtype = "Included" if included else "Excluded"
    # The has*/get* accessors for the two volume kinds differ only by suffix.
    suffix = "Ivol" if included else "Xvol"
    present1 = getattr(hypo1, "has" + suffix)()
    present2 = getattr(hypo2, "has" + suffix)()
    assert present1 == present2, (
        f"has{suffix}() values differ: {present1} != {present2}")
    if not present1:
        return
    vol1 = getattr(hypo1, "get" + suffix)()
    vol2 = getattr(hypo2, "get" + suffix)()

    count1, count2 = vol1.numSpheres(), vol2.numSpheres()
    assert count1 == count2, (
        f"{vtype} volume counts differ: {count1} != {count2}")
    for idx in range(count1):
        sph1 = vol1.getSphere(idx)
        sph2 = vol2.getSphere(idx)
        centre1 = (sph1.x, sph1.y, sph1.z)
        centre2 = (sph2.x, sph2.y, sph2.z)
        coord_mesg = (f"{vtype} volume coordinates differ: "
                      f"{centre1} != {centre2}")
        for c1, c2 in zip(centre1, centre2):
            assert c1 == pytest.approx(c2, abs=tol), coord_mesg
        rad1, rad2 = sph1.r, sph2.r
        assert rad1 == pytest.approx(rad2, abs=tol), (
            f"{vtype} volume radii differ: {rad1} != {rad2}")
def compare_hypos(hypo1, hypo2, tol=FLOAT_TOL):
    """
    Compares the provided PhpHypoAdaptor objects.

    First checks that the hypothesis IDs and feature definitions match. Then
    runs each compare_hypo_* helper in turn: sites, additional CTs,
    constraints, included volumes, masks, QSAR model, radii, reference CT,
    rules, positional tolerances and excluded volumes. ``tol`` is forwarded
    to every helper that accepts it.

    :param hypo1: First PhpHypoAdaptor.
    :param hypo2: Second PhpHypoAdaptor.
    :param tol: Absolute tolerance for floating-point comparisons.
    :raises AssertionError: On the first difference found.
    """
    hypo1_id = hypo1.getHypoID()
    hypo2_id = hypo2.getHypoID()
    mesg = f"Hypothesis IDs differ: {hypo1_id} != {hypo2_id}"
    assert hypo1_id == hypo2_id, mesg
    assert hypo1.getFd() == hypo2.getFd(), "Feature definitions differ"
    compare_hypo_sites(hypo1, hypo2, tol)
    compare_hypo_add_cts(hypo1, hypo2, tol)
    compare_hypo_cnst(hypo1, hypo2, tol)
    compare_hypo_vol(hypo1, hypo2, tol, included=True)
    # compare_hypo_masks may report a mismatch as a non-empty returned
    # message instead of raising. Fail on it explicitly so mask differences
    # are never silently ignored.
    mask_mesg = compare_hypo_masks(hypo1, hypo2)
    assert not mask_mesg, mask_mesg
    compare_hypo_qsar(hypo1, hypo2, tol)
    compare_hypo_rad(hypo1, hypo2, tol)
    compare_hypo_ref(hypo1, hypo2, tol)
    compare_hypo_rules(hypo1, hypo2)
    compare_hypo_tol(hypo1, hypo2, tol)
    compare_hypo_vol(hypo1, hypo2, tol, included=False)
def compare_sites(site1, site2, tol=FLOAT_TOL, compare_site_status=True):
    """
    Compares two PhpSite objects.

    The following must be equal: site number, site type, fragment type,
    mask value, permitted and prohibited types, projected-only flag, and
    (optionally) site status. The radius, tolerance, base coordinates and
    projected-point coordinates must agree to within ``tol``.

    :param site1: First PhpSite.
    :param site2: Second PhpSite.
    :param tol: Absolute tolerance for floating-point comparisons.
    :param compare_site_status: Whether to compare getSiteStatus() values,
        which are reported as surface accessibilities.
    :raises AssertionError: On the first difference found.
    """
    n1 = site1.getSiteNumber()
    n2 = site2.getSiteNumber()
    assert n1 == n2, f"Site numbers differ: {n1} != {n2}"
    type1 = site1.getSiteType()
    type2 = site2.getSiteType()
    assert type1 == type2, f"Site types differ: \"{type1}\" != \"{type2}\""
    # Label taken from site1 and used to prefix all remaining messages.
    sname = site1.getSiteName(0)
    frag1 = site1.getFragType()
    frag2 = site2.getFragType()
    mesg = f"{sname} fragment types differ: \"{frag1}\" != \"{frag2}\""
    assert frag1 == frag2, mesg
    mask1 = site1.getMaskValue()
    mask2 = site2.getMaskValue()
    mesg = f"{sname} mask values differ: {mask1} != {mask2}"
    assert mask1 == mask2, mesg
    perm1 = site1.getPermitted()
    perm2 = site2.getPermitted()
    mesg = f"{sname} permitted types differ: \"{perm1}\" != \"{perm2}\""
    assert perm1 == perm2, mesg
    proh1 = site1.getProhibited()
    proh2 = site2.getProhibited()
    mesg = f"{sname} prohibited types differ: \"{proh1}\" != \"{proh2}\""
    assert proh1 == proh2, mesg
    ponly1 = site1.getProjectedOnly()
    ponly2 = site2.getProjectedOnly()
    mesg = f"{sname} projected only values differ: {ponly1} != {ponly2}"
    assert ponly1 == ponly2, mesg
    rad1 = site1.getRad()
    rad2 = site2.getRad()
    mesg = f"{sname} radii differ: {rad1} != {rad2}"
    assert rad1 == pytest.approx(rad2, abs=tol), mesg
    # Surface-accessibility isn't relevant for all applications, so it
    # doesn't always need to be compared.
    if compare_site_status:
        ss1 = site1.getSiteStatus()
        ss2 = site2.getSiteStatus()
        mesg = f"{sname} surface accessibilities differ: {ss1} != {ss2}"
        assert ss1 == ss2, mesg
    tol1 = site1.getTol()
    tol2 = site2.getTol()
    mesg = f"{sname} tolerances differ: {tol1} != {tol2}"
    assert tol1 == pytest.approx(tol2, abs=tol), mesg
    xyz1 = site1.getCoordinates()
    xyz2 = site2.getCoordinates()
    mesg = f"{sname} base coordinates differ: {xyz1} != {xyz2}"
    for value1, value2 in zip(xyz1, xyz2):
        assert value1 == pytest.approx(value2, abs=tol), mesg
    proj1 = site1.getProjCoords()
    proj2 = site2.getProjCoords()
    len1 = len(proj1)
    len2 = len(proj2)
    mesg = f"{sname} projected point counts differ: {len1} != {len2}"
    assert len1 == len2, mesg
    for pxyz1, pxyz2 in zip(proj1, proj2):
        mesg = f"{sname} projected coordinates differ: {pxyz1} != {pxyz2}"
        for value1, value2 in zip(pxyz1, pxyz2):
            assert value1 == pytest.approx(value2, abs=tol), mesg
def compare_structure_coordinates(st1, st2, tol=FLOAT_TOL, descr=""):
    """
    Asserts that two structure.Structure objects have matching atom
    coordinates.

    Delegates to compare_structure_coordinates_subset, using every atom
    index that st1.getAtomIndices() reports. ``descr`` labels the structure
    in failure messages, e.g. "Conformer 7", "Reference CT",
    "Additional CT 3".
    """
    compare_structure_coordinates_subset(st1, st2, st1.getAtomIndices(),
                                         tol, descr)
def compare_structure_coordinates_subset(st1,
                                         st2,
                                         atom_numbers,
                                         tol=FLOAT_TOL,
                                         descr=""):
    """
    Compares coordinates for a subset of atoms. Numbering starts at 1.

    :param st1: First structure.Structure.
    :param st2: Second structure.Structure.
    :param atom_numbers: Iterable of 1-based atom numbers to compare.
    :param tol: Absolute tolerance for each Cartesian coordinate.
    :param descr: Optional description that prefixes failure messages,
        e.g. "Reference CT". A separating space is added when needed.
    :raises AssertionError: If the atom totals differ, or if any compared
        coordinate differs by more than ``tol``.
    """
    # Add a separator only when there is a description to separate.
    # Otherwise an empty descr would produce messages that start with a
    # space.
    if descr and not descr.endswith(" "):
        descr += " "
    n1 = st1.atom_total
    n2 = st2.atom_total
    assert n1 == n2, f"{descr}atom totals differ: {n1} != {n2}"
    for i in atom_numbers:
        xyz1 = st1.atom[i].xyz
        xyz2 = st2.atom[i].xyz
        mesg = f"{descr}atom {i} coordinates differ: {xyz1} != {xyz2}"
        for value1, value2 in zip(xyz1, xyz2):
            assert value1 == pytest.approx(value2, abs=tol), mesg