"""
Clusters actives and hypotheses into possible binding modes. Actives are
represented by bit strings encoding the hypotheses they match, and
hypotheses are represented by bit strings encoding the actives they match.
Tanimoto similarities between bit strings are computed, and hierarchical,
agglomerative clustering is performed on both actives and hypotheses. The
presence of consistent groupings of actives and hypotheses may indicate the
existence of multiple binding modes.
For example, if there are 10 hypotheses and 8 actives, an idealized clustered
bit matrix for 2 clusters might look like this::
                        Actives                Order
                  H  1 1 1 1 0 0 0 0             7
                  y  1 1 1 1 0 0 0 0             1
                  p  1 1 1 1 0 0 0 0             4
                  o  1 1 1 1 0 0 0 0             0
                  t  1 1 1 1 0 0 0 0             9
                  h  1 1 1 1 0 0 0 0   Cut 0 --- 2
                  e  0 0 0 0 1 1 1 1             6
                  s  0 0 0 0 1 1 1 1             5
                  e  0 0 0 0 1 1 1 1             8
                  s  0 0 0 0 1 1 1 1   Cut 1 --- 3
              Order  3 5 0 2 7 1 6 4
                           |       |
                           |       |
                         Cut 0   Cut 1
Example Usage::
    hypos = hypothesis.extract_hypotheses(phypo_path)
    results = hbm.calculate_binding_modes(hypos, 2)
    cluster_matrix, active_IDs, hypo_IDs, actives_cut, hypo_cut = results
Copyright Schrodinger LLC, All Rights Reserved.
"""
from collections import OrderedDict
from past.utils import old_div
import numpy
from schrodinger.infra import phase
from schrodinger.structure import Structure
def calculate_binding_modes(hypotheses, num_modes):
    """
    Clusters actives and hypotheses into possible binding modes. Returns:

    - clustered bit matrix for actives (columns) and hypotheses (rows)
    - active IDs in column order
    - hypothesis IDs in row order
    - 0-based cluster cutoff indices for actives clusters
    - 0-based cluster cutoff indices for hypotheses clusters

    :param hypotheses: list of Phase hypotheses
    :type hypotheses: list of `hypothesis.PhaseHypothesis`

    :param num_modes: proposed number of binding modes (i.e. clusters)
    :type num_modes: int

    :return: cluster bit matrix (number of hypos x number of actives) as a
             tuple of row tuples, active IDs (tuple), hypothesis IDs (tuple),
             active cut indices, hypo cut indices (the cut indices are passed
             through unchanged from `_perform_clustering`)
    :rtype: tuple

    :raises phase.PhpException: if `_validate_bit_matrix` rejects the bit
                                matrix for the requested number of modes
    """
    # Build the active -> hypothesis-membership mapping and stack its values
    # into a matrix with one row per active and one column per hypothesis.
    actives_bit_dict = _actives_bit_dict(hypotheses)
    bit_matrix = numpy.array(list(actives_bit_dict.values()), dtype=bool)
    validated, msg = _validate_bit_matrix(bit_matrix, num_modes)
    if not validated:
        raise phase.PhpException(msg)

    # Cluster the actives (rows), then the hypotheses (rows of the transpose)
    active_order, active_cuts = _perform_clustering(bit_matrix, num_modes)
    hypo_bit_matrix = numpy.transpose(bit_matrix)
    hypo_order, hypo_cuts = _perform_clustering(hypo_bit_matrix, num_modes)

    # Assemble clustered bit matrix that reflects the clustered order:
    # rows follow hypo_order, columns follow active_order.
    binding_matrix = numpy.zeros(hypo_bit_matrix.shape, dtype=int)
    for i, hypo_index in enumerate(hypo_order):
        for j, active_index in enumerate(active_order):
            binding_matrix[i][j] = bit_matrix[active_index][hypo_index]

    # Sorted IDs for row/column titles. The key list is materialized once
    # instead of once per element of active_order.
    all_active_IDs = list(actives_bit_dict)
    active_IDs = tuple(all_active_IDs[i] for i in active_order)
    hypo_IDs = tuple(hypotheses[i].getHypoID() for i in hypo_order)

    # Return the matrix as an immutable tuple of row tuples
    binding_tuple = tuple(tuple(row) for row in binding_matrix)
    return binding_tuple, active_IDs, hypo_IDs, active_cuts, hypo_cuts
def _get_active_IDs(hypothesis):
    """
    Collects the PHASE_LIGAND_NAME property of the reference ligand and of
    every additional ct in the hypothesis whose PHASE_HYPO_ROLE property
    equals ROLE_ACTIVE.

    :param hypothesis: hypothesis from which to extract active IDs
    :type hypothesis: `hypothesis.PhaseHypothesis`

    :return: ligand names, reference ligand first, then the active
             additional cts in the order `getAddCts()` yields them
             (names presumably look like mol_%d; verify against the data)
    :rtype: list of str

    :raises phase.PhpException: if the hypothesis has no reference ct
    """
    # Guard: a reference ligand is required
    if not hypothesis.hasRefCt():
        raise phase.PhpException(
            "%s does not have reference ligand" % hypothesis.getHypoID())

    # The reference ligand always contributes the first name
    reference = Structure(hypothesis.getRefCt())
    names = [reference.property[phase.PHASE_LIGAND_NAME]]

    # Additional cts contribute only when flagged with the active role
    additional = (Structure(ct) for ct in hypothesis.getAddCts())
    names.extend(
        extra.property[phase.PHASE_LIGAND_NAME]
        for extra in additional
        if extra.property[phase.PHASE_HYPO_ROLE] == phase.ROLE_ACTIVE)
    return names
def _actives_bit_dict(hypotheses):
    """
    Builds an insertion-ordered mapping from active ID to a numpy int array
    with one slot per hypothesis. Slot i is 1 if hypothesis i lists that
    active (per `_get_active_IDs`) and 0 otherwise.

    :param hypotheses: list of hypotheses with actives
    :type hypotheses: list of `hypothesis.PhaseHypothesis`

    :return: bit matrix dictionary where values are numpy arrays of 1/0,
             keyed in the order the actives were first encountered
    :rtype: `collections.OrderedDict`
    """
    num_hypos = len(hypotheses)
    bits_by_active = OrderedDict()
    for hypo_index, hypo in enumerate(hypotheses):
        for name in _get_active_IDs(hypo):
            bits = bits_by_active.get(name)
            if bits is None:
                # First sighting of this active: start from an all-zero row
                bits = numpy.zeros(num_hypos, dtype=int)
                bits_by_active[name] = bits
            # Mark the current hypothesis as containing this active
            bits[hypo_index] = 1
    return bits_by_active
def _validate_bit_matrix(bit_matrix, num_modes):
    """
    Validates the size and composition of the bit matrix based on the number
    of proposed binding modes.

    The checks, in order, are: at least 2 modes requested; the matrix is
    2-D with both dimensions at least `num_modes`; not every bit is set;
    no column is entirely zero.

    :param bit_matrix: Bit matrix indicating hypothesis/active intersections
    :type bit_matrix: `numpy.array`

    :param num_modes: proposed number of binding modes
    :type num_modes: int

    :return: whether the bit matrix is valid, and an error message (empty
             string when valid)
    :rtype: bool, str
    """
    # Verify dimensions against number of proposed binding modes
    if num_modes < 2:
        msg = "The number of proposed binding modes must be 2 or greater"
        return False, msg

    # A matrix built from an empty input is not 2-D; report it as a size
    # failure instead of raising IndexError on shape[1].
    # NOTE(review): the message says "less than" but equality is accepted
    # here - confirm which of the two is intended.
    if bit_matrix.ndim != 2 or min(bit_matrix.shape) < num_modes:
        msg = ("The number of proposed binding modes must be less than the "
               "number of hypotheses and the number of actives.")
        return False, msg

    # Verify all bits are not turned on, otherwise cannot do clustering
    if numpy.sum(bit_matrix) == numpy.prod(bit_matrix.shape):
        msg = "All actives match all hypotheses; cannot cluster binding modes"
        return False, msg

    # Every column must have at least one set bit; a column sum of zero
    # means that column matched nothing.
    if numpy.any(numpy.sum(bit_matrix, axis=0) == 0):
        msg = "There are hypotheses which contains no actives to cluster"
        return False, msg

    return True, ""
def _tanimoto_coefficient(bitrow_i, bitrow_j):
    """
    Computes Tanimoto coefficient between two bit arrays, i.e. the number of
    positions set in both arrays divided by the number of positions set in
    at least one of them.

    :param bitrow_i: vector of bits
    :type bitrow_i: list of ints

    :param bitrow_j: vector of bits
    :type bitrow_j: list of ints

    :return: tanimoto coefficient
    :rtype: float

    :raises ZeroDivisionError: if neither array has any bit set (the union
                               is empty); callers are expected to handle it
    """
    intersection = float(sum(a and b for a, b in zip(bitrow_i, bitrow_j)))
    union = float(sum(a or b for a, b in zip(bitrow_i, bitrow_j)))
    # Both operands are Python floats, so "/" is true division on every
    # Python version and still raises ZeroDivisionError for an empty union.
    return intersection / union
def _distance_matrix(bit_matrix):
    """
    Builds the pairwise row distance matrix used for clustering. The entry
    for rows i > j is (1 - Tanimoto coefficient) of the two rows; a pair for
    which the coefficient is undefined (ZeroDivisionError) gets distance 0.

    Only the strictly lower triangle is written; the diagonal and upper
    triangle stay at zero. NOTE(review): presumably the consumer reads only
    the lower triangle - confirm before relying on symmetry.

    :param bit_matrix: Bit matrix indicating row/column intersections
    :type bit_matrix: `numpy.array`

    :return: 2D numpy array of bit distances between all row pairs
    :rtype: `numpy.array`
    """
    num_rows = bit_matrix.shape[0]
    distances = numpy.zeros((num_rows, num_rows))
    for row, row_bits in enumerate(bit_matrix):
        for col in range(row):
            try:
                similarity = _tanimoto_coefficient(row_bits, bit_matrix[col])
            except ZeroDivisionError:
                # Undefined similarity: leave the pre-filled 0.0 distance
                continue
            distances[row, col] = 1.0 - similarity
    return distances
def _perform_clustering(bit_matrix, num_modes):
    """
    Clusters the rows of a bit matrix with `phase.PhpHiCluster`, feeding it
    the distance matrix produced by `_distance_matrix`.

    :param bit_matrix: Bit matrix indicating hypothesis/active intersections
    :type bit_matrix: `numpy.array`

    :param num_modes: proposed number of binding modes
    :type num_modes: int

    :return: indices sorted by clustering order, indices for cutoff points
    :rtype: list, list
    """
    clusterer = phase.PhpHiCluster()
    clusterer.setDmatrix(_distance_matrix(bit_matrix))
    clusterer.createClusters()
    # The first group returned when asking for a single cluster is used as
    # the overall row ordering.
    row_order = clusterer.getClusters(1)[0]
    return row_order, clusterer.getClusterCutPoints(num_modes)