"""
Module for utilities for restraint generation.
Copyright Schrodinger, LLC. All rights reserved.
"""
import base64
import dataclasses
import json
from collections import Counter
from collections import defaultdict
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy as np
from schrodinger import structure
from schrodinger.application.desmond import cms
from schrodinger.application.desmond.packages import analysis
from schrodinger.application.desmond.packages import msys
from schrodinger.application.desmond.packages import traj
from schrodinger.structutils import analyze
# Key of the decoded restraint JSON dictionary that holds the restraints.
RESTRAINT_KEY = 'restraint'
# Key of the decoded restraint JSON dictionary that holds the "persistent"
# restraints (the ones `clear_encoded_restraints` can keep).
PERSISTENT_RESTRAINT_KEY = 'persistent'
# Name of the property on `comp_ct[-1]` of a cms model under which the encoded
# restraints string is stored.
_ENCODED_RESTRAINT_PROP = "s_desmond_restraint"
# Public names of this module.
__all__ = [
'CrossLinkGenerationError',
'clear_encoded_restraints',
'get_encoded_restraints',
'set_encoded_restraints',
'b64_encode',
'b64_decode',
]
class CrossLinkGenerationError(Exception):
    """
    Raised when no suitable crosslink restraint could be generated.
    """
def clear_encoded_restraints(cms_sys: cms.Cms, keep_persistent: bool = True):
    """
    Remove encoded restraints from the `cms_sys`.

    When `keep_persistent` is true and encoded restraints are present, the
    stored dictionary is decoded, its `RESTRAINT_KEY` entry is replaced by
    its `PERSISTENT_RESTRAINT_KEY` entry (an empty dict if there is none) and
    the result is encoded and stored again. When `keep_persistent` is false
    the restraint property is removed altogether.

    :param cms_sys: "cms" to be processed.
    :type cms_sys: `cms.Cms`
    :param keep_persistent: "persistent" restraints disposition.
    :type keep_persistent: bool
    """
    if not keep_persistent:
        # Drop everything; a missing property is not an error.
        cms_sys.comp_ct[-1].property.pop(_ENCODED_RESTRAINT_PROP, None)
        return

    current = get_encoded_restraints(cms_sys)
    if not current:
        # Nothing stored, nothing to rewrite.
        return

    payload = json.loads(b64_decode(current))
    payload[RESTRAINT_KEY] = payload.get(PERSISTENT_RESTRAINT_KEY, {})
    set_encoded_restraints(cms_sys, b64_encode(json.dumps(payload)))
def get_encoded_restraints(cms_sys):
    """
    Look up the encoded restraints stored on the cms model.

    The value is read from the restraint property of `cms_sys.comp_ct[-1]`
    with `.get` and no explicit default.

    :type cms_sys: `cms.Cms`
    :rtype: `str`
    """
    last_ct_properties = cms_sys.comp_ct[-1].property
    return last_ct_properties.get(_ENCODED_RESTRAINT_PROP)
def set_encoded_restraints(cms_sys, restr):
    """
    Save the encoded restraints on the cms model, as the restraint property
    of `cms_sys.comp_ct[-1]`.

    :type cms_sys: `cms.Cms`
    :type restr: `str`
    """
    last_ct_properties = cms_sys.comp_ct[-1].property
    last_ct_properties[_ENCODED_RESTRAINT_PROP] = restr
def b64_encode(input_string: str) -> str:
    """
    Base64-encode a string, giving back a `str` instead of `bytes` so the
    result can be stored as a string property of a CT.

    :param input_string: string to be encoded
    :return: base64 encoded input
    """
    raw_bytes = input_string.encode()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode()
def b64_decode(input_string: str) -> str:
    """
    Decode a base64 string produced by `b64_encode`.

    The decoded bytes are converted back to `str` so that the function is
    the exact inverse of `b64_encode` and honours its `-> str` annotation
    (`base64.b64decode` on its own returns `bytes`).

    :param input_string: base64 encoded string
    :return: the decoded text
    """
    return base64.b64decode(input_string).decode()
def _check_asls(model: cms.Cms, ligand_asl: str, receptor_asl: str):
"""
:raise RuntimeError: If either ligand_asl or receptor_asl are not valid.
"""
def check_asl(asl, what_molecule):
atoms = model.select_atom(asl)
if 3 > len(atoms):
raise RuntimeError(
"ERROR: Expected the %s molecule to have at least 3 atoms, but "
"found %d." % (what_molecule, len(atoms)))
return set(atoms)
ligand_atoms = check_asl(ligand_asl, "ligand")
receptor_atoms = check_asl(receptor_asl, "receptor")
if ligand_atoms & receptor_atoms:
raise RuntimeError(
"""ERROR: ligand atoms and receptor atoms should NOT have overlaps.
Ligand ASL expression: %s
Receptor ASL expression: %s
Atoms selected by both expressions: %s""" %
(ligand_asl, receptor_asl, ligand_atoms & receptor_atoms))
def _check_restraint(restraint):
    """
    Make sure a restraint was found.

    :raise CrossLinkGenerationError: If `restraint` is None.
    """
    if restraint is not None:
        return
    message = ("ERROR: Unable to find a suitable "
               "crosslink restraint. Check the "
               "trajectory for an unstable ligand.")
    raise CrossLinkGenerationError(message)
# Centroid/Interaction utility functions
# Frozen with eq=True so that instances are hashable and can be used as dict keys
@dataclasses.dataclass(frozen=True, eq=True)
class PLInteractionAids:
    """
    Atom indexes describing one protein-ligand interaction: the ligand atom
    and the N, CA and C atoms on the receptor side.

    The dataclass is frozen and comparable, which makes instances hashable.
    """
    ligand_aid: int
    receptor_n_aid: int
    receptor_ca_aid: int
    receptor_c_aid: int

    @property
    def receptor_aids(self) -> Tuple[int, int, int]:
        # The receptor atom indexes in N, CA, C order.
        n_aid = self.receptor_n_aid
        ca_aid = self.receptor_ca_aid
        c_aid = self.receptor_c_aid
        return (n_aid, ca_aid, c_aid)
@dataclasses.dataclass
class CentroidData:
    """
    Data related to the centroid of a group of atoms.
    """
    # aids used to compute the centroid
    aids: List[int]
    # delta vector between coords and the centroid
    diff: np.ndarray
    # centroid coords
    centroid: np.ndarray
def _get_centroid_data(ct: structure.Structure, asl: str) -> CentroidData:
    """
    Compute the centroid of the atoms selected by `asl` and package it,
    together with the selected atom indexes and each atom's offset from the
    centroid, into a `CentroidData`.
    """
    selected_aids = analyze.evaluate_asl(ct, asl)
    xyz = ct.extract(selected_aids).getXYZ()
    center = np.mean(xyz, axis=0)
    return CentroidData(aids=selected_aids, diff=xyz - center, centroid=center)
def _find_centroid_aid(ct: structure.Structure, asl: str) -> int:
    """
    Find, among the atoms selected by `asl`, the index of the atom that lies
    closest to their centroid.
    """
    centroid_data = _get_centroid_data(ct, asl)
    distances = np.linalg.norm(centroid_data.diff, axis=1)
    return centroid_data.aids[np.argmin(distances)]
def _find_aids_within_cutoff_of_centroid(ct: structure.Structure,
                                         asl: str,
                                         cutoff=2) -> List[int]:
    """
    Return the indexes of the atoms that are within r_min + cutoff of the
    centroid, where r_min is the distance from the centroid to its closest
    atom.

    :param cutoff: Margin added to r_min; atoms no farther than
        r_min + cutoff from the centroid are returned.
        Default is 2 Angstrom.
    """
    centroid_data = _get_centroid_data(ct, asl)
    distances = np.linalg.norm(centroid_data.diff, axis=1)
    threshold = distances[np.argmin(distances)] + cutoff
    return [
        centroid_data.aids[position]
        for position, distance in enumerate(distances)
        if distance <= threshold
    ]
def _find_max_dist_aid(ct: structure.Structure, asl: str, ref_aid: int) -> int:
    """
    Find the atom, among those selected by `asl`, that lies farthest from
    the reference atom and return its index.

    :param asl: The asl for atoms to search.
    :param ref_aid: The atom index of the reference atom.
    """
    candidate_aids = analyze.evaluate_asl(ct, asl)
    reference_xyz = ct.extract([ref_aid]).getXYZ()
    candidate_xyz = ct.extract(candidate_aids).getXYZ()
    distances = np.linalg.norm(candidate_xyz - reference_xyz, axis=1)
    return candidate_aids[np.argmax(distances)]
def _find_axis_aids(ct: structure.Structure,
asl: str,
aid1: int,
aid2: int,
aid3: int,
delta1=30,
delta2=60) -> List[int]:
"""
Given indexes of 3 atoms (not colinear),
find index of atoms (aid4) so that the aid2-aid3-aid4 angle is 90+/-delta1 and
the aid1-aid2-aid3-aid4 dihedral is 90+/-delta2.
:param delta1: interval for angle, default is 30 degrees.
:param delta2: interval for dihedral, default is 60 degrees.
:return: aid1, aid2, aid3: indexes for 3 atoms or empty list if no
match found.
"""
min_ang = 90 - delta1
max_ang = 90 + delta1
min_dih = 90 - delta2
max_dih = 90 + delta2
at1 = ct.atom[aid1]
at2 = ct.atom[aid2]
at3 = ct.atom[aid3]
angle = ct.measure(at1, at2, at3)
# Check for duplicate atoms
if len(set([at1, at2, at3])) != 3:
return []
# Check for colinear atoms
if angle < 5 or angle > 175:
return []
group_ids = []
for i in analyze.evaluate_asl(ct, asl):
at4 = ct.atom[i]
angle = ct.measure(at2, at3, at4)
dihed = ct.measure(at1, at2, at3, at4)
if angle > min_ang and angle < max_ang and \
dihed > min_dih and dihed < max_dih:
group_ids.append(i)
return group_ids
def _get_two_atoms(ct: structure.Structure, aid: int) -> List[Tuple[int, int]]:
    """
    Collect pairs of atom indexes (second, third) around the atom `aid`.

    `second` is an atom bonded to the given atom, and `third` is an atom
    bonded either to the given atom or to `second` (and different from
    both). Only atoms with at least two bonded heavy atoms are used for
    `second` and `third`.
    """
    center = ct.atom[aid]
    pairs = []
    for second in center.bonded_atoms:
        if _bonded_heavy_atom_count(second) < 2:
            continue
        candidates = [*center.bonded_atoms, *second.bonded_atoms]
        pairs.extend(
            (second.index, third.index)
            for third in candidates
            if third != center and third != second and
            _bonded_heavy_atom_count(third) >= 2)
    return pairs
def _bonded_heavy_atom_count(at: structure._structure._StructureAtom) -> int:
# Return the number of heavy atoms attached to the given atom.
return sum(bonded_at.atomic_number > 1 for bonded_at in at.bonded_atoms)
def _get_heavy_asl(asl: str) -> str:
return f'({asl}) and (not a.e H)'
def _get_backbone_aids(ct: structure.Structure,
receptor_aid: int) -> Tuple[int, int, int]:
"""
Given a structure and a protein atom index, return the
corresponding N, Ca, C backbone aids.
"""
res = ct.atom[receptor_aid].getResidue()
n = res.getBackboneNitrogen()
ca = res.getAlphaCarbon()
c = res.getCarbonylCarbon()
if not all([n, ca, c]):
return (None, None, None)
return (n.index, ca.index, c.index)
def _get_heavy_aid(ct: structure.Structure, aid: int) -> Optional[int]:
    """
    Starting from the atom `aid`, walk along single heavy-atom neighbours
    until an atom with at least two bonded heavy atoms is reached, and
    return the index of that atom.

    Return None if an atom with no bonded heavy atom is met, or if the walk
    would revisit an atom.
    """
    current = ct.atom[aid]
    # Atoms the walk has already moved onto. Without this, two terminal
    # heavy atoms bonded to each other (CH3-CH3) would loop forever.
    visited = set()
    while True:
        heavy_count = _bonded_heavy_atom_count(current)
        if heavy_count == 0:
            return None
        if heavy_count != 1:
            # Non-terminal heavy atom found.
            return current.index
        # Hydrogen or terminal heavy atom: step onto its only heavy neighbour.
        neighbour = next(
            bonded for bonded in current.bonded_atoms
            if bonded.atomic_number > 1)
        if neighbour in visited:
            return None
        visited.add(neighbour)
        current = neighbour
def _is_so2(ct: structure.Structure, aid: int) -> bool:
    """
    Return True if the atom is a sulfur with exactly four bonded heavy atoms,
    exactly two of which are oxygens.
    """
    atom = ct.atom[aid]
    if atom.atomic_number != 16 or _bonded_heavy_atom_count(atom) != 4:
        return False
    oxygen_count = sum(
        neighbor.atomic_number == 8 for neighbor in atom.bonded_atoms)
    return oxygen_count == 2
def _get_protein_ligand_interaction_freq_dict(
    msys_model: "msys.System",  # noqa: F821
    cms_model: cms.Cms,
    tr: List["traj.TrajFrame"],
    ligand_asl: str,
    receptor_asl: str,
    num_traj_segments: int = 1
) -> Optional[Dict[PLInteractionAids, List[float]]]:
    """
    Return the normalized frequency for the interactions between
    the protein and ligand. The keys are `PLInteractionAids`.

    Hydrogen bonds and salt bridges between the receptor and ligand atoms are
    collected for every frame. Each interaction is keyed by the ligand heavy
    atom (see `_get_heavy_aid`) and by the backbone N, CA and C atoms of the
    receptor residue involved; interactions with residues lacking one of
    these backbone atoms are skipped.

    :param msys_model: msys model passed on to the interaction analyzers.
    :param cms_model: cms model the ASLs are evaluated on.
    :param tr: Trajectory frames to analyze.
    :param ligand_asl: ASL selecting the ligand atoms.
    :param receptor_asl: ASL selecting the receptor atoms.
    :param num_traj_segments: Number of segments to split the trajectory into
        prior to running the analysis. The frequencies are computed
        for each segment.
    :return: Mapping from interaction to a list with one frequency (count
        divided by the number of frames) per trajectory segment. None if no
        interaction was found at all, or if no heavy atom could be found
        for a ligand atom.
    """
    # NOTE(review): the "+ 1" only rounds the segment length up when
    # num_traj_segments is 2; for other values some trailing frames may not
    # belong to any segment. Left unchanged to preserve existing results.
    nfr = (len(tr) + 1) // num_traj_segments
    segments = [tr[i * nfr:(i + 1) * nfr] for i in range(num_traj_segments)]

    # ASLs
    receptor_aids = analyze.evaluate_asl(cms_model, receptor_asl)
    ligand_aids = analyze.evaluate_asl(cms_model, ligand_asl)
    # Membership is tested once per interaction per frame, so use a set.
    ligand_aid_set = set(ligand_aids)

    # Run protein-ligand interaction analyzers
    freqs = []
    for segment in segments:
        hbond_finder = analysis.HydrogenBondFinder(msys_model, cms_model,
                                                   receptor_aids, ligand_aids)
        salt_bridge_finder = analysis.SaltBridgeFinder(
            msys_model, cms_model, receptor_aids, ligand_aids)
        hbonds = analysis.analyze(segment, hbond_finder)
        salt_bridges = analysis.analyze(segment, salt_bridge_finder)

        # Get the frequency for each pair of interactions
        freq = Counter()
        for frame_hbonds, frame_bridges in zip(hbonds, salt_bridges):
            for (pro_aid, lig_aid) in frame_hbonds + frame_bridges:
                # Put the pair into (protein, ligand) order.
                if lig_aid not in ligand_aid_set:
                    pro_aid, lig_aid = lig_aid, pro_aid
                pro_n_aid, pro_ca_aid, pro_c_aid = _get_backbone_aids(
                    cms_model.fsys_ct, pro_aid)
                if pro_n_aid is None:
                    # Skip interaction with terminal residue
                    continue
                lig_aid = _get_heavy_aid(cms_model.fsys_ct, lig_aid)
                if lig_aid is None:
                    print('Could not find attached heavy atom.')
                    return None
                pair = PLInteractionAids(lig_aid, pro_n_aid, pro_ca_aid,
                                         pro_c_aid)
                freq[pair] += 1
        freqs.append(freq)

    # Corner case: interaction missing from trajectory segment.
    # Make sure all segments have the same interaction keys.
    pairs = {pair for freq in freqs for pair in freq}
    for freq in freqs:
        for pair in pairs:
            if pair not in freq:
                freq[pair] = 0.0

    # No hydrogen bond or salt bridge interactions found
    if not pairs:
        return None

    # Normalize the frequency for each segment and store as a list
    result = defaultdict(list)
    for freq, segment in zip(freqs, segments):
        # Guard against dividing by zero for an empty segment.
        num_frames = len(segment) or 1
        for pair, count in freq.items():
            result[pair].append(count / num_frames)
    return result