"""
Custom assertions for "checking" structures.  Uses the Protein Report to
determine whether a structure has (for instance) bad steric clashes, or
unreasonable bond angles.
See test_structurecheck.py for examples.  It is also used extensively in psp.
copyright (c) Schrodinger, LLC. All rights reserved.
"""
from schrodinger.protein.analysis import Report as ProteinReport
# Marker read by the unittest machinery: a module that defines ``__unittest``
# in its globals has its frames left out of failure tracebacks, so a failed
# check points at the calling test instead of at this helper module.
__unittest = True
"""Keeps stack trace from including this module."""
class ProteinReportCheck:
    """
    Wrapper around ProteinReport used for unit testing and scientific
    testing.

    An instance holds a structure together with a set of cutoffs.
    `reportSet` returns the ProteinReport data points that are worse than
    those cutoffs, and `assertProteinHealth` raises when a measurement has
    more such points than the caller allows.
    """

    def __init__(self,
                 ct,
                 steric_delta=None,
                 steric_distance=None,
                 bond_length_deviation=0.1,
                 bond_angle_deviation=10.0,
                 peptide_planarity_deviation=15.0,
                 sidechain_planarity_deviation=0.02,
                 improper_torsion_deviation=5.0):
        """
        :param ct: Structure to operate on.
        :param steric_delta: Provide a buffer of this size (A) between the
                sum of VdW radii of the two atoms and their distance
                (non-bonded atoms only) to determine what is a clash.
                ``None`` disables this cutoff.
        :param steric_distance: Only atom pairs closer than this distance
                will be counted as a clash.  ``None`` disables this cutoff.
        :param bond_length_deviation: Bond length deviation of more than
                this will be counted as a bond length problem.
        :param bond_angle_deviation: Bond angle deviation of more than this
                (deg) will be counted as a bond angle problem.
        :param peptide_planarity_deviation: Residues where the atoms that
                make up the peptide plane deviate from planarity more than
                this (deg) will be considered a problem.
        :param sidechain_planarity_deviation: Any aromatic residue where the
                atoms that make up the aromatic unit deviate more than this
                amount (A) from the plane will be considered a problem.
        :param improper_torsion_deviation: Any improper torsion which
                deviates from the canonical amount by more than this will be
                considered a problem.  NOTE(review): earlier documentation
                gave the unit as rad, which looks doubtful for a default of
                5.0 -- confirm against ProteinReport.
        """
        # Store the structure and the minimum values to report.  Any value
        # worse than these will be reported in reportSet and counted in
        # assertProteinHealth.
        self.ct = ct
        self.steric_delta = steric_delta
        self.steric_distance = steric_distance
        self.bond_length_deviation = bond_length_deviation
        self.bond_angle_deviation = bond_angle_deviation
        self.peptide_planarity_deviation = peptide_planarity_deviation
        self.sidechain_planarity_deviation = sidechain_planarity_deviation
        self.improper_torsion_deviation = improper_torsion_deviation

    def assertProteinHealth(self, name, test_values):
        """
        Raise an assertion error if criteria are not met.

        :type name: str
        :param name: Title that is used in the error output.
        :type test_values: dict
        :param test_values: Key is the name of the test to perform and the
                value is the maximum number of exceptions allowed.
        :raise AssertionError: If the number of exceptions to any criterion
                exceeds the maximum allowed.  The message contains one
                section per failing test.
        """
        sections = []
        for test, allowed in test_values.items():
            title, header, lines = self.reportSet(test)
            if len(lines) > allowed:
                sections.append("Error for %s:\n" % name +
                                " Too many errors for %s (%d>%d)\n" %
                                (title, len(lines), allowed) + "%s\n%s" %
                                (header, "\n".join(lines)))
        if sections:
            # Separate the sections with a newline so the last data line of
            # one failing test is not fused with the next "Error for" line.
            raise AssertionError("\n".join(sections))

    def reportSet(self, measurement):
        """
        Return the data points of one measurement that are worse than the
        cutoffs given to the initializer.

        Any property calculated by ProteinReport can be selected; a
        measurement without a dedicated cutoff is returned unfiltered.

        :type measurement: str
        :param measurement: Measurement to check
        :return: Tuple of the title of the measurement (str), a header line
                with the column titles (str) and the formatted data points
                (list of str).
        :raise RuntimeError: If ProteinReport has no set named
                ``measurement``.
        """
        if (measurement == 'STERIC CLASHES'):
            # NOTE(review): assumes the ProteinReport steric clash values are
            # ordered (distance, minimum allowed, delta) -- confirm.
            # The delta column only gets a lower cutoff.  Passing
            # steric_delta as its upper cutoff as well would keep only the
            # points whose delta is exactly equal to steric_delta.
            return self._reportSet(
                measurement,
                min_values=[None, None, self.steric_delta],
                max_values=[self.steric_distance, None, None])
        elif (measurement == 'BOND LENGTHS'):
            return self._reportSet(measurement,
                                   min_values=[self.bond_length_deviation])
        elif (measurement == 'BOND ANGLES'):
            return self._reportSet(measurement,
                                   min_values=[self.bond_angle_deviation])
        elif (measurement == 'BACKBONE DIHEDRALS'):
            return self._reportSet(measurement,
                                   in_values=[None, None, "disallowed"])
        elif (measurement == 'SIDECHAIN DIHEDRALS'):
            return self._reportSet(measurement,
                                   in_values=[None, None, "disallowed"])
        elif (measurement == 'PEPTIDE PLANARITY'):
            # The cutoff is applied to the reported value as an upper bound
            # of 180 minus the allowed deviation.  A deviation of None (no
            # cutoff) makes the subtraction raise TypeError.
            try:
                value = 180.0 - self.peptide_planarity_deviation
            except TypeError:
                value = None
            return self._reportSet(measurement, max_values=[value])
        elif (measurement == 'SIDECHAIN PLANARITY'):
            return self._reportSet(
                measurement, min_values=[self.sidechain_planarity_deviation])
        elif (measurement == 'IMPROPER TORSIONS'):
            # Points whose first value contains '-' are dropped before the
            # numeric comparison is attempted.
            return self._reportSet(measurement,
                                   min_values=[self.improper_torsion_deviation],
                                   not_in_values=['-'])
        elif (measurement == 'CHIRALITY'):
            return self._reportSet(measurement, not_in_values=['L'])
        else:
            return self._reportSet(measurement)

    @staticmethod
    def _cutoff(cutoffs, index):
        """
        Return ``cutoffs[index]``, or None when no cutoff exists at that
        position.
        """
        if cutoffs is None or index >= len(cutoffs):
            return None
        return cutoffs[index]

    @classmethod
    def _isFiltered(cls, values, max_values, min_values, in_values,
                    not_in_values):
        """
        Return True if a point with these values fails any of the filters.

        The values are checked in order, and for each value the checks run
        in the order not_in, max/min, in.  The order matters: an earlier
        check can reject a point before a later comparison would fail on a
        non-numeric value.
        """
        for index, value in enumerate(values):
            forbidden = cls._cutoff(not_in_values, index)
            if forbidden is not None:
                try:
                    if forbidden in value:
                        return True
                except TypeError:
                    # The value does not support ``in`` (e.g. a number), so
                    # it cannot contain the forbidden token.
                    pass
            upper = cls._cutoff(max_values, index)
            if upper is not None and value > upper:
                return True
            lower = cls._cutoff(min_values, index)
            if lower is not None and value < lower:
                return True
            required = cls._cutoff(in_values, index)
            if required is not None:
                try:
                    if required not in value:
                        return True
                except TypeError:
                    # The value does not support ``in``, so it cannot
                    # contain the required token.
                    return True
        return False

    @staticmethod
    def _formatPoint(point):
        """
        Format one data point as a fixed-width line.
        """
        line = "%-30s" % point.descriptor
        for value in point.values:
            try:
                line += "%8.2f" % value
            except TypeError:
                # Non-numeric values are written as text.
                line += "%10s" % str(value)
        return line

    def _reportSet(self,
                   name,
                   max_values=None,
                   min_values=None,
                   in_values=None,
                   not_in_values=None):
        """
        Run ProteinReport for a single set and format the points that pass
        the filters.

        Each filter is a sequence indexed like the values of a data point; a
        missing or None entry means there is no constraint on that value.

        :param name: Name of the ProteinReport set to run.
        :param max_values: A point with a value greater than the matching
                entry is dropped.
        :param min_values: A point with a value less than the matching entry
                is dropped.
        :param in_values: A point whose value does not contain the matching
                entry is dropped.
        :param not_in_values: A point whose value contains the matching entry
                is dropped.
        :return: Tuple of the set title, the header line and the list of
                formatted lines of the points that were kept.
        :raise RuntimeError: If ProteinReport produced no data for ``name``.
        """
        report = ProteinReport(self.ct, sets_to_run=[name])
        try:
            pr_set = report.data[0]
        except IndexError:
            raise RuntimeError("Set name %s not valid in ProteinReport" % name)

        # Keep only the points that are not rejected by a filter.
        point_lines = []
        for point in pr_set.points:
            if self._isFiltered(point.values, max_values, min_values,
                                in_values, not_in_values):
                continue
            point_lines.append(self._formatPoint(point))

        # Create the header: a wide first column, then one column per field.
        header_line = "%-30s" % pr_set.fields[0]
        for header in pr_set.fields[1:]:
            header_line += "%-10s" % header
        return pr_set.title, header_line, point_lines
def assert_no_major_problems(ct):
    """
    Assert that a structure has no severe geometry problems.

    Builds a `ProteinReportCheck` with a steric delta of 1.0, a bond length
    deviation of 0.1 and a bond angle deviation of 20.0, and allows zero
    reported points for each of the three measurements.

    :param ct: Structure to check.  Its ``title`` is used to label the
            error message.
    :raise AssertionError: If the connection table contains a
            SEVERE steric clash, bond length or angle deviation.
    """
    checker = ProteinReportCheck(ct,
                                 steric_delta=1.0,
                                 bond_length_deviation=0.1,
                                 bond_angle_deviation=20.0)
    # Maximum number of reported points tolerated per measurement.
    allowed_problems = {
        'STERIC CLASHES': 0,
        'BOND LENGTHS': 0,
        'BOND ANGLES': 0,
    }
    checker.assertProteinHealth(ct.title, allowed_problems)