"""
Desmond workup techniques, all based on actual numerical results. Intended for
import and use as a test workup.
Running directly allows execution of any of the workup methods
Methods for comparing the following Desmond-style information:
    - `Metadynamics FES files<desmond_compare_FES>`
    - `FEP energies from FEP 'result' files<desmond_FEP_energies>`
    - Surface area of membrane systems
    - .ene files:
        - `similarity to a reference ene file<desmond_compare_ene>`
    - log files: `Calculation speed (ns/day)<desmond_speed>`
    - `st2 files<desmond_st_stats>`: comparison of median, mean, and standard
        deviation to reference values
    - `deltaE.txt files<desmond_compare_deltaE>`: compare two deltaE.txt files
Classes and helper functions have 'private' names, as only the workup methods
are intended to be imported and used as test workups.
$Revision 0.1 $
@TODO: improve speed test to print avg ns/day.
@TODO: improve speed test to be dependent on cpu^0.8
@copyright: (c) Schrodinger, Inc. All rights reserved.
"""
import gzip
import os
import re
import sys
import tarfile
from collections import Counter
from collections import defaultdict
from functools import partial
from past.utils import old_div
from typing import TYPE_CHECKING
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import numpy
import pytest
from more_itertools import pairwise
from schrodinger import structure
from schrodinger.application.desmond import cmj
from schrodinger.application.desmond import constants
from schrodinger.application.desmond import stage
from schrodinger.utils import sea
from . import standard_workups
if TYPE_CHECKING:
    from schrodinger.application.scisol.packages.fep import graph
# Revision tag of this module.
_version = "$Revision 0.1 $"
# Matches lines of the form "[T<int>, T<int>] : [<field>, <field>]" and
# captures the two integers and the two fields (as strings), in that order.
# Used by read_dE_file to parse the replica-pair lines of a dE file.
REPLICA = re.compile(r"\[T(\d+), *T(\d+)\] *: *\[(.+), *(.+)\]")
# *****************************************************************************
# METADYNAMICS
def _FES(filename):
    """Returns a numpy array based on a FES format file."""
    fes1 = []
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if "#" in line:
                continue
            elif line:
                fes1.append([float(x) for x in line.split()])
    return numpy.array(fes1)
def desmond_compare_FES(file1, file2, tolerance=0.1):
    """
    Compare metadynamics FES files.

    The two files are read with `_FES` and compared through the root mean
    square deviation taken over all of their values.

    :param file1: Path of the first FES file.
    :type file1: str
    :param file2: Path of the second FES file.
    :type file2: str
    :param tolerance: Largest accepted RMSD. Optional, defaults to 0.1.
    :type tolerance: float
    :return: `True` if the RMSD does not exceed `tolerance`.
    :raise AssertionError: If the two arrays have different shapes or the RMSD
        exceeds `tolerance`.
    """
    fes1 = _FES(file1)
    fes2 = _FES(file2)
    # Compare shapes explicitly: relying on the subtraction to fail would let
    # broadcast-compatible shapes (e.g. (3, 2) vs (1, 2)) slip through.
    if fes1.shape != fes2.shape:
        raise AssertionError("FES shapes were different!")
    difference = fes2 - fes1
    rmsd = numpy.sqrt(numpy.mean(difference * difference))
    if rmsd > tolerance:
        raise AssertionError("FES were different, RMSD=%.2f" % rmsd)
    return True
#THIS CODE WILL GENERATE A FES:
#from schrodinger.application.desmond.meta import MetaDynamicsAnalysis
#meta_data = MetaDynamicsAnalysis(, inp_fname=opts.input)
#FES = meta_data.computeFES(outputname)
#
#alternately, just add an analysis stage to a metadynamics job!
# *****************************************************************************
# FEP
def desmond_FEP_energies(filename, reference, tolerance=0.5):
    """
    Compares the deltaF in an FEP result file to a standard.

    Only the first line containing "deltaF" is considered; its fourth
    whitespace-separated field is taken as the deltaF value.

    :param filename: Path of the FEP result file.
    :type filename: str
    :param reference: Expected deltaF value.
    :type reference: float
    :param tolerance: Largest accepted absolute difference. Defaults to 0.5.
    :type tolerance: float
    :return: `True` if the value is within `tolerance` of `reference`.
    :raise AssertionError: If the value differs too much or no line containing
        "deltaF" exists.
    """
    with open(filename) as result:
        first_hit = next((text for text in result if "deltaF" in text), None)
    if first_hit is None:
        raise AssertionError("   deltaF not found in file %s" % filename)
    delta_f = float(first_hit.split()[3])
    if abs(delta_f - reference) > tolerance:
        raise AssertionError("   Different deltaF: {} vs {} ".format(
            delta_f, reference))
    return True
# *****************************************************************************
# Membrane Systems
def desmond_check_surface_area(filename,
                               reference,
                               tolerance=0.5,
                               side1="r_chorus_box_ax",
                               side2="r_chorus_box_bx"):
    """
    Compares surface area to a standard with optional tolerance and optional
    choice of box sides.  Default sides are a and b.

    The area is the product of the two named properties of the first
    structure in the file, and is compared to `reference` as a relative
    difference.

    example usage: desmond_check_surface_area('filename', 7.5, tolerance=0.1)

    :param filename: Structure file readable by `structure.MaestroReader`.
    :type filename: str
    :param reference: Expected surface area.
    :type reference: float
    :param tolerance: Largest accepted relative difference. Defaults to 0.5.
    :type tolerance: float
    :param side1: Name of the structure property holding the first side.
    :type side1: str
    :param side2: Name of the structure property holding the second side.
    :type side2: str
    :return: `True` if the area is within tolerance.
    :raise AssertionError: If the relative difference exceeds `tolerance`.
    """
    struct = next(structure.MaestroReader(filename))
    area = struct.property[side1] * struct.property[side2]
    # True division on purpose: a floor division of two integral operands
    # would round the relative error down to zero.
    if abs(area - reference) / reference > tolerance:
        raise AssertionError(f"   Areas were different: {area} vs {reference}")
    return True
def desmond_tgz_file_list(tgz_fname, *fnames_to_check):
    """
    Check that a tar archive contains exactly the given member names.

    Both name lists are sorted before comparison, so order does not matter.

    :param tgz_fname: Path of the tar archive (any compression understood by
        `tarfile.open`).
    :type tgz_fname: str
    :param fnames_to_check: Member names the archive is expected to contain.
    :type fnames_to_check: str
    :return: `True` if the archive members match the expected names.
    :raise AssertionError: If the names differ; the message lists the names
        found followed by the names expected.
    """
    with tarfile.open(tgz_fname) as tf:
        found = sorted(tf.getnames())
    expected = sorted(fnames_to_check)
    if found == expected:
        return True
    # Format the lists explicitly; str.join() cannot join two lists.
    raise AssertionError(f"{found} {expected}")
# *****************************************************************************
# .ene files
class _EneFile:
    """
    Class to find and contain the information in a Desmond .ene file.  Also
    does comparisons based on:
        - compare initial and final energies to reference
        - compare temperature profiles (useful for SAmd)
        - number of time points and their separation
    """
    def __init__(self, filename=None, tol=10.0):
        self.eneData = {}
        self.eneUnits = {}
        self.tol = tol
        if filename:
            self.load_system(filename)
        else:
            self._filename = None
    def filename(self, filename=None):
        """
        Gets and sets the filename.  Prevents having a loaded system which
        does not match the file name.
        """
        if filename:
            self.load_system(filename)
        else:
            return self._filename
    def load_system(self, filename):
        """Loads the information from file filename into the EneFile object."""
        self._filename = filename
        name_rows = []  #list describes the property name in each row
        with open(self._filename) as ene_file:
            for line in ene_file:
                if "#" == line[0]:
                    array = line[1:].split()
                    if '0' == array[0][0]:
                        for elem in array:
                            if ":" in elem:
                                name = elem.split(':')[-1]
                                name_rows.append(name)
                                self.eneData[name] = []
                            elif "(" == elem[0] and ")" == elem[-1]:
                                #Add the name of the unit based on the most
                                #recent quantity name found
                                self.eneUnits[name_rows[-1]] = elem[1:-1]
                            else:
                                raise Exception(
                                    "unknown row type in row keys list")
                elif name_rows:
                    for i, datum in enumerate(line.split()):
                        self.eneData[name_rows[i]].append(float(datum))
    def quantity_average(self, quantity):
        """
        Finds the average of a quantity.
        """
        return numpy.average(self.eneData[quantity])
    def compare_units(self, other):
        """
        Compares two EneFile objects based on whether the quantities are
        measured in the same units in both files.
        """
        for value in set(list(self.eneUnits) + list(other.eneUnits)):
            if self.eneUnits.get(value, "") != other.eneUnits.get(value, ""):
                msg = ("  Units did not match for %s (%s vs %s)" %
                       (value, self.eneUnits.get(
                           value, ""), other.eneUnits.get(value, "")))
                raise AssertionError(msg)
        return True
    def compare_quantity(self, other, quantity):
        """
        Compares two EneFile objects based on the initial and final and average value of a
        named quantity.
        """
        max_diff = 0.1**self.tol
        if abs(self.eneData[quantity][0] -
               other.eneData[quantity][0]) > max_diff:
            msg = "  Initial values for {} differ: {} vs {}".format(
                quantity, self.eneData[quantity][0], other.eneData[quantity][0])
            raise AssertionError(msg)
        if abs(self.eneData[quantity][-1] -
               other.eneData[quantity][-1]) > max_diff:
            msg = "  Final values for {} differ: {} vs {}".format(
                quantity, self.eneData[quantity][-1],
                other.eneData[quantity][-1])
            raise AssertionError(msg)
        if abs(
                self.quantity_average(quantity) -
                other.quantity_average(quantity)) > max_diff:
            msg = "  Average values for {} differ: {} vs {}".format(
                quantity, self.quantity_average(quantity),
                other.quantity_average(quantity))
            raise AssertionError(msg)
        return True
    def compare_times(self, other):
        """
        Compares two EneFile objects based on the number of points in each
        simulation as well as the final time point.
        """
        r_value = (len(self.eneData['time']) == len(other.eneData['time']) and
                   self.eneData['time'][-1] == other.eneData['time'][-1])
        if not r_value:
            raise AssertionError("  Time profiles did not match")
        return True
    def compare_temperature_profile(self, other):
        """
        Compares two EneFile objects based on the temperatures at each
        timepoint.
        """
        temps_self = numpy.array(self.eneData["T"])
        temps_other = numpy.array(other.eneData["T"])
        diffs = numpy.abs(temps_self - temps_other)
        if max(diffs) > 40:
            msg = "  The maximum difference in Temperature at any given time"\
                    " point was large: {}".format(max(diffs))
            raise AssertionError(msg)
        return True
def desmond_compare_ene(file1, file2, tolerance, *quantities):
    """
    Compares two Desmond ene files.

    Each requested quantity must match between the two files (initial, final
    and average value) to within `tolerance` decimal places.

    usage desmond_compare_ene('file1.ene', 'file2.ene', 2, 'P', 'E_c')

    :param file1: Path of the first ene file.
    :param file2: Path of the second ene file.
    :param tolerance: Number of decimal places that have to agree.
    :param quantities: Names of the ene columns to compare.
    :return: `True` if every requested quantity matches.
    :raise AssertionError: If any quantity differs.
    """
    first = _EneFile(file1, tolerance)
    second = _EneFile(file2, tolerance)
    outcomes = [first.compare_quantity(second, name) for name in quantities]
    return all(outcomes)
def desmond_compare_eneseqs(file1,
                            file2,
                            *,
                            columns=None,
                            min_length=10,
                            atol=1e-2,
                            rtol=1e-5,
                            mean_atol=1e-4,
                            mean_rtol=1e-5):
    """
    Compare columns between two Desmond ene files at matching time points.

    Only time points present in both files are compared. For every column
    the individual values and the mean over the common time points must be
    close. Divergent individual values are printed to stdout before the
    assertion fails.

    :param file1: Path of the first ene file.
    :type file1: str
    :param file2: Path of the second ene file.
    :type file2: str
    :param columns: Column names to compare. Use all by default.
    :type columns: sequence of str
    :param min_length: Require at least that many matching time points.
    :type min_length: int
    :param atol: Absolute tolerance (see `numpy.isclose`).
    :type atol: float
    :param rtol: Relative tolerance (see `numpy.isclose`).
    :type rtol: float
    :param mean_atol: Absolute tolerance for mean values (see `numpy.isclose`).
    :type mean_atol: float
    :param mean_rtol: Relative tolerance for mean values (see `numpy.isclose`).
    :type mean_rtol: float
    :raise AssertionError: If the columns, time values or data do not match.
    """
    TIME_COLUMN = 'time'
    ene1 = _EneFile(file1)
    ene2 = _EneFile(file2)
    if columns is None:
        assert ene1.eneData.keys() == ene2.eneData.keys(), \
            "  Column names do not match"
        columns = ene1.eneData.keys() - {TIME_COLUMN}
    else:
        assert set(columns) <= ene1.eneData.keys(), \
            "  Some of the columns are missing in file1"
        assert set(columns) <= ene2.eneData.keys(), \
            "  Some of the columns are missing in file2"
    # Each file's own time column has to be strictly increasing, which is
    # also what justifies assume_unique=True below.
    for (label, sequence) in (('file1', ene1.eneData[TIME_COLUMN]),
                              ('file2', ene2.eneData[TIME_COLUMN])):
        assert all(a < b for a, b in pairwise(sequence)), \
            f"  Non-monotonic time values in {label}"
    common_times, idx1, idx2 = numpy.intersect1d(ene1.eneData[TIME_COLUMN],
                                                 ene2.eneData[TIME_COLUMN],
                                                 assume_unique=True,
                                                 return_indices=True)
    num_common = len(idx1)
    assert min_length > 0
    assert num_common >= min_length, \
        f"  Too few matching time points: {num_common} < {min_length}"
    for c in sorted(columns):
        x1 = numpy.asarray(ene1.eneData[c])[idx1]
        x2 = numpy.asarray(ene2.eneData[c])[idx2]
        close = numpy.isclose(x1, x2, atol=atol, rtol=rtol)
        if not close.all():
            not_close = [arr[~close] for arr in (common_times, x1, x2)]
            print(f"Divergent values of '{c}' at:")
            for t_, x1_, x2_ in zip(*not_close):
                print(f"t={t_}: {x1_} vs {x2_}")
        assert close.all(), \
            f"  Divergent values of '{c}', see stdout for details"
        mean1, mean2 = x1.mean(), x2.mean()
        assert numpy.isclose(mean1, mean2, atol=mean_atol, rtol=mean_rtol), \
            f"  Divergent values of mean('{c}'): {mean1} vs {mean2}"
# *****************************************************************************
# log files
def desmond_speed(filename, reference_rate, tolerance=0.05):
    """
    Compares the average rate per step in ns/day against a standard.

    Only the first line containing "Total rate per step" is considered; the
    rate is read from its second-to-last whitespace-separated field.

    example usage: desmond_speed('file1.log', 7.5, tolerance=0.1)

    :param filename: Path of the log file.
    :param reference_rate: Value to be compared against in ns/day
    :param tolerance: Tolerance. Default is 5%.
    :return: `True` if the rate is within tolerance of `reference_rate`.
    :raise AssertionError: If the relative difference exceeds `tolerance` or
        no "Total rate per step" line is found.
    """
    with open(filename) as logfile:
        for line in logfile:
            if "Total rate per step" not in line:
                continue
            rate = float(line.split()[-2])
            # rate is always a float, so plain true division is what the
            # relative difference needs.
            if abs(rate - reference_rate) / reference_rate > tolerance:
                msg = ("  Rates: new: %.1f, reference: %.1f" %
                       (rate, reference_rate))
                raise AssertionError(msg)
            return True
    raise AssertionError("  Total rate per step not found.")
def desmond_stage_time(log_file, stage_num, max_time):
    """
    Check that a stage finished successfully and quickly enough.

    The duration is read from the last line containing
    "Stage <stage_num> duration:", whose remainder is expected to consist of
    three fields for hours, minutes and seconds, each followed by a one
    character unit suffix (e.g. "0h 1m 23s").

    :param log_file : Name of log file to be analyzed.
    :type log_file : str
    :param stage_num : Stage number
    :type stage_num : int
    :param max_time : Maximum acceptable time for the stage in seconds
    :type max_time : float
    :return : `True` if the stage completed successfully within `max_time` seconds, or `False` if not.
    """
    search_str = 'Stage %d duration:' % stage_num
    with open(log_file) as fh:
        log_text = fh.read()
    # Keep only what follows the marker. str.strip(search_str) would treat the
    # marker as a set of characters and could eat digits of the duration.
    durations = [
        line.split(search_str, 1)[1]
        for line in log_text.split('\n')
        if search_str in line
    ]
    if not durations:
        return False
    multisim_complete = (
        "\nStage %d completed successfully." % stage_num) in log_text
    # Drop the trailing unit character from each of the three fields.
    h, m, s = [e[:-1] for e in durations[-1].split()]
    duration = float(h) * 3600 + float(m) * 60 + float(s)
    return ((duration < max_time) and multisim_complete)
# *****************************************************************************
# st2 files.
#compare averages and ranges.
def desmond_st_stats(filename, *options):
    """
    Compares the mean, median and/or standard deviation of a desmond st2 file to
    reference values.

    usage: desmond_st_stats('file.st2', 'mean=1.2', 'median=2.3', 'stddev=1.2', 'tol=0.2')

    :param filename: Path of the st2 file.
    :type filename: str
    :param options: "key=value" strings. Recognised keys (case-insensitive)
        are "mean", "median", "stddev" and any key starting with "tol" (the
        absolute tolerance, 0.4 by default). Other keys are ignored.
    :type options: str
    :raise AssertionError: If a requested statistic differs from its
        reference by more than the tolerance, or if anything else goes wrong
        while evaluating the file contents.
    """
    with open(filename) as fh:
        cfg = sea.Map(fh.read())
    meanRef, medianRef, stddevRef = None, None, None
    tolerance = 0.4
    for a in options:
        key, val = a.split("=")
        val = float(val)
        key = key.lower()
        if key == "mean":
            meanRef = val
        elif key == "median":
            medianRef = val
        elif key == "stddev":
            stddevRef = val
        elif key[:3] == "tol":
            tolerance = val
    try:
        values = []
        # Only the values of the last keyword visited are kept.
        for st in cfg.Keywords:
            for key in st:
                values = [x.val for x in st[key]["Result"]]
                if key == "Angle" or key == "Torsion":
                    values = _normalize_angular_values(values)
        errors = []
        # Compare against None so that a reference of 0 is still checked.
        if meanRef is not None:
            _compare_stats(values, meanRef, tolerance, numpy.mean, "Mean",
                           errors)
        if medianRef is not None:
            _compare_stats(values, medianRef, tolerance, numpy.median, "Median",
                           errors)
        if stddevRef is not None:
            _compare_stats(values, stddevRef, tolerance, numpy.std,
                           "Standard deviation", errors)
        if errors:
            raise AssertionError('\n'.join(errors))
    except Exception as e:
        raise AssertionError(str(e)) from e
def _compare_stats(values, reference, tolerance, function, name, errors):
    value = function(values)
    if abs(reference - value) > tolerance:
        msg = f"   {name} differs: {value} vs. {reference}"
        errors.append(msg)
def _normalize_angular_values(array):
    """
    Centers an array of angles in the unit of degree to minimize its standard
    deviation.
    :param array : Array to be (possibly) modified.
    :type array : list
    :return : Values, some of which (possibly) offset (by 360 deg) to minimize std dev.
    :rtype : list
    """
    good_array = array
    stddev = numpy.std(good_array)
    for shift in numpy.arange(0, 360, 2.7):
        tmp = list(((elem + shift) % 360 - shift) for elem in array)
        if numpy.std(tmp) < stddev:
            good_array = tmp
            stddev = numpy.std(good_array)
    good_array -= (numpy.mean(good_array) // 360) * 360
    return good_array.tolist()
def read_dE_file(fname):
    """
    Read a dE file and return a dict from time to energy_dict;
    energy_dict is a dict from (lambda_win, lambda_win) to energy value

    A line starting with "Time" opens a new time entry, keyed by the text
    after the first six characters of that line. Each following line matching
    `REPLICA` adds two entries to the current time: the forward value under
    (r0, r1) and the reverse value under (r1, r0). Times and energy values
    are kept as the strings found in the file.

    :param fname: Path of the dE file.
    :type fname: str
    :rtype: dict
    """
    energy_by_time = {}
    # Replica lines seen before the first "Time" line land in this throwaway
    # dict and are not returned.
    current = {}
    with open(fname) as fh:
        for line in fh:
            line = line.strip()
            # startswith() is safe on blank lines, unlike indexing line[0].
            if line.startswith('Time'):
                current = energy_by_time[line[6:]] = {}
            elif line.startswith('['):
                match = REPLICA.match(line)
                if match:
                    r0, r1, ef, er = match.group(1, 2, 3, 4)
                    r0 = int(r0)
                    r1 = int(r1)
                    current[(r0, r1)] = ef
                    current[(r1, r0)] = er
    return energy_by_time
def desmond_compare_deltaE(file1, file2):
    """
    Compares two Desmond deltaE.txt files.

    Checks time and energy. No tolerance (exact sameness).

    usage: desmond_compare_deltaE('file1', 'file2')

    :param file1: Path of the first deltaE.txt file.
    :param file2: Path of the second deltaE.txt file.
    :return: `True` if both files yield identical contents when read with
        `read_dE_file`, `False` otherwise.
    :rtype: bool
    """
    return read_dE_file(file1) == read_dE_file(file2)
def desmond_compare_energy_group(file1, file2, tolerance=0):
    """
    Compares two Desmond enegrp.dat files.

    Checks all values upto tolerance value. The files are compared line by
    line. A line of file1 containing '(' is treated as a value line: the
    text before '(' is the key and the text after ')' is the value. Any other
    line is treated as a series of whitespace-separated `key=value` fields.
    Keys have to be identical and values have to agree within `tolerance`.
    If one file has more lines than the other, the surplus lines have to be
    blank.

    usage: desmond_compare_energy_group('file1', 'file2', tolerance=0)

    :param file1: Path of the first enegrp.dat file.
    :param file2: Path of the second enegrp.dat file.
    :param tolerance: Largest accepted absolute difference between values.
    :return: `True` if the files match.
    :raise AssertionError: On the first pair of lines that differ; the
        message carries the zero-based line index and both lines.
    """
    from itertools import zip_longest

    def my_compare(v1, v2):
        return abs(float(v1) - float(v2)) <= tolerance

    def split_value_line(line):
        # Returns (key, value text), or None if '(' or ')' is missing.
        try:
            key = line[:line.index('(')].strip()
            value = line[line.index(')') + 1:].strip()
        except ValueError:
            return None
        return key, value

    def compare_value_line(l1, l2):
        parsed1 = split_value_line(l1)
        parsed2 = split_value_line(l2)
        if parsed1 is None or parsed2 is None:
            return False
        (k1, v1), (k2, v2) = parsed1, parsed2
        return k1 == k2 and my_compare(v1, v2)

    def read_time_line(l):
        d = {}
        for s in l.split():
            k, _, v = s.partition('=')
            d[k] = v
        return d

    def compare_time_line(l1, l2):
        d1 = read_time_line(l1)
        d2 = read_time_line(l2)
        if d1.keys() != d2.keys():
            return False
        return all(my_compare(v1, d2[k]) for k, v1 in d1.items())

    with open(file1) as fh1, open(file2) as fh2:
        # zip_longest pads the shorter file with empty lines, so content that
        # exists in only one of the files is reported instead of ignored.
        for i, (line1, line2) in enumerate(
                zip_longest(fh1, fh2, fillvalue='')):
            if '(' in line1:
                lines_match = compare_value_line(line1, line2)
            else:
                lines_match = compare_time_line(line1, line2)
            if not lines_match:
                msg = f'line {i} {line1} differ from {line2}'
                raise AssertionError(msg)
    return True
def desmond_test_mean_num_waters(tarname, limits):
    """
    Check that the mean number of waters in the simulation is within the given
    limits.

    The numbers are read from the first member of the tar file whose name
    ends in '.log'. The last field of every line starting with one of the two
    recognised phrases contributes one sample to the mean.

    :param tarname: the input tar file name
    :type tarname: str
    :param limits: a 2-tuple containing the min and max values
    :type limits: tuple
    :return: Whether the mean lies strictly between the two limits.
    :raise ValueError: If the tar file has no '.log' member.
    """
    # Format: Total solvent in simulation box: <N> Average: <x>
    average_prefix = 'Total solvent in simulation box: '
    # compatibility with quiet format:
    # GCMC Moves accepted ... Total solvent molecules: <N>
    quiet_prefix = 'GCMC moves accepted: '
    samples = []
    with tarfile.open(tarname) as archive:
        log_name = next(
            (member for member in archive.getnames()
             if member.endswith('.log')), None)
        if log_name is None:
            raise ValueError("No logfile found")
        with archive.extractfile(log_name) as stream:
            for raw in stream:
                # extractfile yields bytes, so decode before matching
                text = raw.decode('utf-8')
                if text.startswith(average_prefix):
                    # running average at the end of a GCMC plugin iteration
                    samples.append(float(text.split()[-1]))
                elif text.startswith(quiet_prefix):
                    # instantaneous (not averaged) number of particles
                    samples.append(int(text.split()[-1]))
    # Average over possibly many GCMC plugin iterations
    mean_num_waters = numpy.mean(samples)
    lower, upper = limits
    return lower < mean_num_waters < upper
def desmond_check_fep_main_msj(main_msj_fname: str,
                               fep_type: constants.FEP_TYPES):
    """
    Basic checks for the main msj.

    Walks over the stages of the msj and checks that:
        - the mapper stage matching `fep_type` is present and has the
          arguments required for that fep type,
        - the launcher stage dispatches exactly the expected per-leg msj
          files (plus the fragment linking ones for small molecule FEP),
        - the cleanup stage is present,
        - the protein mutation generator stage is present for protein fep
          types, and only for those.

    :param main_msj_fname: The input main msj file name.
    :param fep_type: Type of fep used to run the job.
    :return: `True` if all checks pass.
    :raise AssertionError: On the first failed check.
    """
    from schrodinger.application.scisol.packages.fep.graph import Legtype
    legs = (Legtype.COMPLEX.value, Legtype.SOLVENT.value)
    job_name = os.path.splitext(os.path.basename(main_msj_fname))[0]
    msj_as_sea = cmj.msj2sea(main_msj_fname)
    found_mapper = False
    found_launcher = False
    found_fep_cleanup = False
    found_protein_generator = False
    mapper_name = {
        constants.FEP_TYPES.SMALL_MOLECULE: stage.FepMapper.NAME,
        constants.FEP_TYPES.METALLOPROTEIN: stage.FepMapper.NAME,
        constants.FEP_TYPES.COVALENT_LIGAND: stage.CovalentFepMapper.NAME,
        constants.FEP_TYPES.PROTEIN_STABILITY: stage.ProteinFepMapper.NAME,
        constants.FEP_TYPES.PROTEIN_SELECTIVITY: stage.ProteinFepMapper.NAME,
        constants.FEP_TYPES.LIGAND_SELECTIVITY: stage.ProteinFepMapper.NAME,
    }
    # Fail with a readable message instead of a KeyError from the lookup
    # inside the stage loop.
    assert fep_type in mapper_name, f"Unknown fep_type: {fep_type}"
    expected_mapper = mapper_name[fep_type]
    has_fragment_linking = fep_type == constants.FEP_TYPES.SMALL_MOLECULE
    fragment_linking = constants.SIMULATION_PROTOCOL.FRAGMENT_LINKING
    for s in msj_as_sea.stage:
        msg_base = f"{main_msj_fname}: {s.__NAME__}"
        if s.__NAME__ == expected_mapper:
            found_mapper = True
            # Check mapper args
            if fep_type == constants.FEP_TYPES.SMALL_MOLECULE:
                pass
            elif fep_type == constants.FEP_TYPES.COVALENT_LIGAND:
                assert len(s.graph_file.val), \
                    f"{msg_base}.graph_file is not set."
            elif fep_type == constants.FEP_TYPES.METALLOPROTEIN:
                assert len(s.mp.val), \
                    f"{msg_base}.mp is not set."
                assert any('i_fep_' in v for v in s.mp.val), \
                    f"{msg_base}.mp is not set correctly, missing 'i_fep_' in name."
            elif fep_type in constants.PROTEIN_FEP_TYPES:
                assert s.fep_type.val == fep_type, \
                    f"{msg_base}.fep_type is not set correctly. Found {s.fep_type.val}, expected {fep_type}."
                if fep_type in [
                        constants.FEP_TYPES.PROTEIN_SELECTIVITY,
                        constants.FEP_TYPES.LIGAND_SELECTIVITY
                ]:
                    assert s.asl.val, \
                        f"{msg_base}.asl is empty for selectivity calculation."
            else:
                assert False, f"Unknown fep_type: {fep_type}"
        elif s.__NAME__ == stage.FepLauncher.NAME:
            found_launcher = True
            for ileg, leg in enumerate(legs):
                expected_msj = {
                    constants.SIMULATION_PROTOCOL.CHARGED: f'{job_name}_chg_{leg}.msj',
                    constants.SIMULATION_PROTOCOL.FORMALCHARGED: f'{job_name}_chg_{leg}.msj',
                    constants.SIMULATION_PROTOCOL.COREHOPPING: f'{job_name}_corehopping_{leg}.msj',
                    constants.SIMULATION_PROTOCOL.MACROCYCLE_COREHOPPING: f'{job_name}_corehopping_{leg}.msj',
                    constants.SIMULATION_PROTOCOL.DEFAULT: f'{job_name}_{leg}.msj',
                }
                if has_fragment_linking:
                    expected_msj[fragment_linking] = \
                        f'{job_name}_fragment_linking_{leg}.msj'
                for name, msj_fname in expected_msj.items():
                    assert name in s.dispatch.val, f"Missing {name} dispatcher."
                    assert msj_fname in s.dispatch[name].val[
                        ileg], f"Missing {msj_fname} in {name} dispatcher."
                assert s.dispatch.val.keys() == expected_msj.keys(), \
                    f"Additional dispatchers found: " \
                    f"{set(s.dispatch.val.keys()) - set(expected_msj.keys())}, " \
                    "please update desmond_workups and the stu tests."
            if has_fragment_linking:
                # Name the dispatcher explicitly rather than relying on the
                # variable left over from the loop above.
                for ileg, leg in enumerate(constants.FRAGMENT_LINKING_JOBS.VAL):
                    msj_fname = f'{job_name}_fragment_linking_{leg}.msj'
                    assert msj_fname in s.dispatch[fragment_linking].val[
                        ileg], f"Missing {msj_fname} in {fragment_linking} dispatcher."
        elif s.__NAME__ == stage.ProteinMutationGenerator.NAME:
            assert fep_type in constants.PROTEIN_FEP_TYPES, f"{msg_base}: not a protein fep calculation!"
            assert s.mutation_file.val, f"{msg_base}.mutation_file is empty."
            found_protein_generator = True
        found_fep_cleanup = found_fep_cleanup or s.__NAME__ == stage.FepMapperCleanup.NAME
    assert found_mapper, "Mapper stage not found"
    assert found_launcher, "Launcher stage not found"
    assert found_fep_cleanup, "Cleanup stage not found"
    if fep_type in constants.PROTEIN_FEP_TYPES:
        assert found_protein_generator, "Protein mutation generator stage not found"
    return True
# Expected lambda setting for each simulation protocol and for the fragment
# hydration job, as a "<name>:<integer>" string.
# NOTE(review): presumably "<lambda schedule>:<number of lambda windows>" --
# confirm against the code that consumes this table (not in the visible part
# of this file).
_EXPECTED_LAMBDAS = {
    constants.SIMULATION_PROTOCOL.CHARGED: 'charge:24',
    constants.SIMULATION_PROTOCOL.FORMALCHARGED: 'charge:24',
    constants.SIMULATION_PROTOCOL.COREHOPPING: 'flexible:16',
    constants.SIMULATION_PROTOCOL.MACROCYCLE_COREHOPPING: 'flexible:16',
    constants.SIMULATION_PROTOCOL.DEFAULT: 'default:12',
    constants.SIMULATION_PROTOCOL.FRAGMENT_LINKING: 'flexible:16',
    constants.FRAGMENT_LINKING_JOBS.VAL.FRAGMENT_HYDRATION: 'default:24'
}
def _check_edge_sid(edge: "graph.Edge", exp_num_lambdas_complex: int,
                    exp_num_lambdas_solvent: int):
    """
    Check the edge's sid report to check for the appropriate number of
    replicas for both the complex and solvent legs

    For each leg the sid also has to contain an FEPSimulation keyword whose
    PerturbationType equals the fep type of the edge's graph.

    Related cases: DESMOND-8836, DESMOND-8811, DESMOND-8805,
    DESMOND-8674, DESMOND-8243, DESMOND-8224, DESMOND-8208.

    :param edge: graph.Edge object to check
    :param exp_num_lambdas_complex: Expected number of replicas in the
        complex leg.
    :param exp_num_lambdas_solvent: Expected number of replicas in the
        solvent leg.
    :raise AssertionError: If a sid is missing or fails one of the checks.
    """
    from schrodinger.application.scisol.packages.fep.graph import Legtype
    fep_type = edge.graph.fep_type
    legs = (Legtype.COMPLEX.value, Legtype.SOLVENT.value)
    sids = [edge.complex_sid, edge.solvent_sid]
    expected_num_replicas = [exp_num_lambdas_complex, exp_num_lambdas_solvent]
    for leg, sid, exp_num in zip(legs, sids, expected_num_replicas):
        assert sid and len(
            sid) > 0, f"Edge {edge.short_id} is missing {leg}_sid."
        sid_obj = sea.Map(sid)
        checked_fep_simulation = False
        checked_num_replicas = False
        for k in sid_obj['Keywords']:
            if 'FEPSimulation' in k:
                assert k['FEPSimulation']['PerturbationType'].val == fep_type, \
                    f'Could not find "{fep_type}" in {leg} sid perturbation_type for {edge.short_id}.'
                checked_fep_simulation = True
            if 'Replica' in k:
                # DESMOND-8811
                assert len(k['Replica'].val) == exp_num, \
                    f"Job ran with wrong number of lambdas. Found {len(k['Replica'].val)}, " \
                    f"expected {exp_num}."
                checked_num_replicas = True
        # These messages need the f prefix so that the edge id is filled in.
        assert checked_fep_simulation, f"Did not find FEPSimulation in {edge.short_id} sid."
        assert checked_num_replicas, f"Did not find Replica in {edge.short_id} sid."
def _check_edge_parched(edge: "graph.Edge", min_frames=21):
    """
    Verify the parched trajectories and the representative structures
    stored in the graph's fmpdb for a single edge.

    Only lambda index 0 is examined for absolute binding graphs; otherwise
    lambda indices 0 and 1 are examined. Trajectories are checked for both
    the complex and solvent legs, representative structures for the complex
    leg only.

    Related cases: DESMOND-8839, DESMOND-8819, DESMOND-8812, DESMOND-8596,
    DESMOND-8499, DESMOND-8279, DESMOND-8198.

    :param edge: graph.Edge object to check
    :param min_frames: Minimum number of frames in the parched trajectories.
                       Default is 21 for the number of frames in a
                       standard 5 ns FEP job.
    :raises AssertionError: if any of the checks fails.
    """
    from schrodinger.application.desmond.packages import traj_util
    from schrodinger.application.scisol.packages.fep.graph import Legtype

    assert edge.graph.fmpdb is not None, "graph has no fmpdb file."
    is_absolute = edge.graph.fep_type == constants.FEP_TYPES.ABSOLUTE_BINDING
    lambdas = [0] if is_absolute else [0, 1]

    edge.graph.fmpdb.open('.')
    try:
        fmpdb_fname = edge.graph.fmpdb.filename
        for leg in (Legtype.COMPLEX, Legtype.SOLVENT):
            trajs = edge.trajectories(leg)
            assert trajs, f"{fmpdb_fname} is missing the trajectory for {edge.short_id} {leg.value}."
            for ilambda in lambdas:
                _, lambda_cms, lambda_traj = traj_util.read_cms_and_traj(
                    trajs[ilambda][0])
                try:
                    assert len(lambda_traj) >= min_frames, \
                        f"Trajectory for lambda {ilambda}, edge {edge.short_id}, leg {leg.value} has {len(lambda_traj)} frames, expected at least {min_frames} frames."
                    # Select the waters that are not alchemical ions or
                    # alchemical waters.
                    water_asl = f'water and not (atom.{constants.ALCHEMICAL_ION} or atom.{constants.CT_TYPE} "{constants.CT_TYPE.VAL.ALCHEMICAL_WATER}")'
                    extracted = traj_util.extract_atoms(lambda_cms, water_asl,
                                                        lambda_traj)
                    for (ct, _) in extracted:
                        assert len(ct.molecule) == 200, \
                        f"Trajectory for lambda {ilambda}, edge {edge.short_id}, leg {leg.value} "\
                        f" has wrong number of waters. Found {len(ct.molecule)}, expected 200."
                finally:
                    if lambda_traj:
                        # Close the reader to close the filehandle,
                        # otherwise fmpdb.close() can fail.
                        # DESMOND-8868
                        lambda_traj[0].source().close()

        # Check representative structure (complex leg only).
        leg = Legtype.COMPLEX
        repr_fnames = edge.representative_strucs(leg)
        assert repr_fnames is not None, f"{fmpdb_fname} is missing the representative struc for {edge.short_id} {leg.value}."
        for ilambda in lambdas:
            lambda_repr_ct = structure.Structure.read(repr_fnames[ilambda])
            assert lambda_repr_ct.atom_total > 0, \
                f"Representative structure for lambda {ilambda}, edge {edge.short_id}, leg {leg.value}, has no atoms."
    finally:
        # Always release the fmpdb, whether or not a check failed.
        edge.graph.fmpdb.close()
def _check_graph_environment(g: "graph.Graph", membrane=False):
    """
    Validate the environment structures stored on the graph.

    The set of expected environments depends on the graph's fep type (a
    receptor or nothing), extended with the membrane and solvent when
    `membrane` is set. Each expected `<environment>_struc` must be present
    and must be classified by `fepmae.filter_receptors_and_ligands` as that
    environment and nothing else.

    :param g: The graph object to check.
    :param membrane: Set to True if this is a membrane fep run.
                     Default is False.
    :raises ValueError: if the graph's fep type is not recognized.
    :raises AssertionError: if any of the checks fails.
    """
    from schrodinger.application.scisol.packages.fep import fepmae

    fep_types = constants.FEP_TYPES
    types_with_receptor = (
        fep_types.SMALL_MOLECULE,
        fep_types.METALLOPROTEIN,
        fep_types.PROTEIN_SELECTIVITY,
        fep_types.LIGAND_SELECTIVITY,
        fep_types.ABSOLUTE_BINDING,
    )
    types_without_environment = (
        fep_types.COVALENT_LIGAND,
        fep_types.PROTEIN_STABILITY,
    )
    if g.fep_type in types_with_receptor:
        environments = ["receptor"]
    elif g.fep_type in types_without_environment:
        environments = []
    else:
        raise ValueError(f'unknown fep_type: {g.fep_type}')
    if membrane:
        environments += ["membrane", "solvent"]

    component_names = ('receptor', 'solvent', 'membrane', 'ligands')
    for environment in environments:
        env_ct = getattr(g, f'{environment}_struc')
        assert env_ct is not None, f"graph {environment}_struc is missing."
        results = fepmae.filter_receptors_and_ligands([env_ct])
        results_dict = {
            name: results[idx] for idx, name in enumerate(component_names)
        }
        # For ligand selectivity, the ligand is stored in the receptor_struc
        if environment == "receptor" and g.fep_type == constants.FEP_TYPES.LIGAND_SELECTIVITY:
            assert len(results_dict['ligands']) == 1, \
                f"{g.fep_type} should have one ligand as the receptor. Found {len(results_dict['ligands'])} ligand(s). {results_dict}"
            results_dict['receptor'] = results_dict['ligands'][0]
            results_dict['ligands'] = []
        for k, v in results_dict.items():
            if k != environment:
                # Every other component must be empty for this structure.
                assert not v, \
                    f"graph {environment}_struc is missing or has wrong type."
                continue
            # Fall back to the object itself when it has no title.
            found_title = getattr(v, 'title', v)
            env_title = getattr(env_ct, 'title', env_ct)
            assert env_ct is v, \
                f"graph {environment}_struc is not correct. Found {found_title}, expected {env_title}."

    found_env_count = sum(1 for ct in g.environment_struc if ct is not None)
    exp_env_count = len(environments)
    assert found_env_count == exp_env_count, \
        f'graph environment_struc is not correct. Found {found_env_count}, expected {exp_env_count} ct(s).'
def _check_graph_nodes(g: "graph.Graph", node_ids: List[str]):
    """
    Check the node structures for a graph.
    :param g: The graph object to check.
    :param node_ids: List of expected node ids.
    """
    assert g.number_of_nodes() > 0, "No nodes found in graph"
    found_node_ids = [node.short_id for node in g.nodes_iter()]
    extra_node_ids = set(found_node_ids) - set(node_ids)
    missing_node_ids = set(node_ids) - set(found_node_ids)
    assert len(extra_node_ids) == 0, \
        f'Extra nodes found in graph: {extra_node_ids}'
    assert len(missing_node_ids) == 0, \
        f'Missing nodes in graph: {missing_node_ids}'
    if g.fep_type in [
            constants.FEP_TYPES.SMALL_MOLECULE,
            constants.FEP_TYPES.METALLOPROTEIN,
            constants.FEP_TYPES.PROTEIN_SELECTIVITY,
            constants.FEP_TYPES.LIGAND_SELECTIVITY,
            constants.FEP_TYPES.ABSOLUTE_BINDING,
    ]:
        for node in g.nodes_iter():
            assert node.struc is not None, \
                f"graph node {node.short_id} has empty struc."
            assert node.protein_struc is None, \
                f"graph node {node.short_id} has non-empty protein_struc."
    elif g.fep_type in [
            constants.FEP_TYPES.COVALENT_LIGAND,
            constants.FEP_TYPES.PROTEIN_STABILITY
    ]:
        for node in g.nodes_iter():
            assert node.struc is not None, \
                f"graph node {node.short_id} has empty struc."
            assert node.protein_struc is not None, \
                f"graph node {node.short_id} has empty protein_struc."
def _check_fep_backend_log(log_contents: str):
    """
    Check the Desmond backend log for expected contents.
    :param log_contents: Desmond backend log file contents.
    """
    assert 'GPU Desmond' in log_contents, \
        'GPU Desmond is not detected.'
    assert '32 FEP_GPGPU' in log_contents or '64 FEP_GPGPU' in log_contents, \
        'Desmond is missing FEP_GPGPU license checkout in backend log file.'
def _check_fep_cms(cms_fname: str,
                   cts: List[structure.Structure],
                   leg: str,
                   fep_type: constants.FEP_TYPES,
                   membrane=False):
    """
    Verify that a cms file holds the expected number of cts of each type.

    The cts are tallied by their `constants.CT_TYPE` property and compared
    with the counts expected for the given fep type, leg and membrane
    setting.

    Related cases: DESMOND-7328, DESMOND-8602, DESMOND-8894.

    :param cms_fname: Filename of the cms.
    :param cts: List of structures from the cms.
    :param leg: Name of the leg to check.
    :param fep_type: Type of fep used to run the job.
    :param membrane: Set to True if this is a membrane fep run.
                     Default is False.
    :raises AssertionError: if a count differs or the fep type is unknown.
    """
    from schrodinger.application.scisol.packages.fep.graph import Legtype

    ffio_types = Counter(ct.property[constants.CT_TYPE] for ct in cts)
    is_complex_leg = leg == Legtype.COMPLEX.value

    if fep_type in [constants.FEP_TYPES.SMALL_MOLECULE]:
        num_solutes = 3 if is_complex_leg else 2
        if leg in constants.FRAGMENT_LINKING_HYDRATION_JOBS:
            num_solutes = 1
        expect_membrane = membrane and is_complex_leg
    elif fep_type in [
            constants.FEP_TYPES.PROTEIN_SELECTIVITY,
            constants.FEP_TYPES.LIGAND_SELECTIVITY
    ]:
        num_solutes = 3 if is_complex_leg else 2
        # Selectivity jobs expect the membrane ct in every leg.
        expect_membrane = membrane
    elif fep_type in [
            constants.FEP_TYPES.METALLOPROTEIN,
            constants.FEP_TYPES.COVALENT_LIGAND,
            constants.FEP_TYPES.PROTEIN_STABILITY
    ]:
        num_solutes = 2
        expect_membrane = membrane and is_complex_leg
    else:
        assert False, f"Unknown fep_type: {fep_type}"

    expected_ffio_types = {
        constants.CT_TYPE.VAL.FULL_SYSTEM: 1,
        constants.CT_TYPE.VAL.SOLUTE: num_solutes,
        constants.CT_TYPE.VAL.SOLVENT: 1,
    }
    if expect_membrane:
        expected_ffio_types[constants.CT_TYPE.VAL.MEMBRANE] = 1

    for ffio_type, count in expected_ffio_types.items():
        assert ffio_types[ffio_type] == count, \
            f'cms file {cms_fname} does not have {count} {ffio_type} ct(s), found {ffio_types[ffio_type]} ct(s). ' \
            f'{ffio_types}'
def _check_fep_in_cfg(cfg_contents: str,
                      edge: "graph.Edge",
                      leg: str,
                      expected_lambdas=_EXPECTED_LAMBDAS):
    """
    Check the in.cfg file given a graph edge.

    :param cfg_contents: in.cfg file contents
    :param edge: Corresponding edge from the graph.
    :param leg: Name of the leg to check.
    :param expected_lambdas: Mapping from the edge's simulation protocol (or
                             the fragment hydration job key) to the value
                             expected for `fep.lambda` in the cfg.
                             If not given, use the default for each type.
    :raises AssertionError: if the lambda schedule differs from the expected
                            one, or a macrocycle core-hopping edge does not
                            have `backend.soft_bond_alpha` of 0.5.
    """
    cfg = sea.Map(cfg_contents)
    expected_lambda = expected_lambdas[edge.simulation_protocol]
    # Fragment hydration legs of a fragment linking edge use their own
    # schedule.
    if edge.is_fragment_linking and leg in constants.FRAGMENT_LINKING_HYDRATION_JOBS:
        expected_lambda = expected_lambdas[
            constants.FRAGMENT_LINKING_JOBS.VAL.FRAGMENT_HYDRATION]
    found_lambda = cfg.fep['lambda'].val
    assert found_lambda == expected_lambda, \
        f'Wrong lambda schedule in in.cfg. ' \
        f'Found {found_lambda}, expected: {expected_lambda}'
    # Make sure macrocycle core-hopping has correct soft bond alpha
    if edge.simulation_protocol == constants.SIMULATION_PROTOCOL.MACROCYCLE_COREHOPPING:
        assert pytest.approx(
            cfg.backend['soft_bond_alpha'].val
        ) == 0.5, "Macrocycle missing soft bond alpha in in.cfg."
def _check_deltaE(deltaE_contents: str, threshold=300):
    """
    Check the deltaE file contents to make sure all of the energy values
    have an absolute value below `threshold`.

    Only lines matching the module-level `REPLICA` pattern contribute; the
    two bracketed values after the colon on each such line are checked.

    Related cases: DESMOND-9425, DESMOND-9429.

    :param deltaE_contents: Contents of the deltaE.txt file.
    :param threshold: Maximum absolute value for the energy in kcal/mol.
    :raises ValueError: if no line matches the replica pattern, or a matched
                        value cannot be converted to float.
    :raises AssertionError: if any energy reaches or exceeds the threshold.
    """
    vals = []
    for line in deltaE_contents.split('\n'):
        match = REPLICA.match(line.strip())
        if match:
            vals.extend((float(match.group(3)), float(match.group(4))))
    if not vals:
        # Without this guard numpy raises an opaque "zero-size array"
        # ValueError from min(); keep the exception type but say why.
        raise ValueError("No replica energies found in deltaE contents.")
    vals = numpy.array(vals)
    msg = "deltaE values larger than threshold"
    # Checking both extremes bounds the absolute value of every entry.
    assert numpy.abs(
        vals.min()) < threshold, f"{msg}: |{vals.min()}| > {threshold}"
    assert numpy.abs(
        vals.max()) < threshold, f"{msg}: |{vals.max()}| > {threshold}"
def _check_fep_launcher(launcher_path: str,
                        g: "graph.Graph",
                        membrane=False,
                        deltaE_threshold=None,
                        expected_lambdas=_EXPECTED_LAMBDAS):
    """
    Check the output files generated by the FepLauncher stage.

    For every edge and leg, the multisim log is parsed to find the
    lambda_hopping stage number, and the backend log, in.cfg, optional
    deltaE.txt and cms files inside that stage's -out.tgz are checked.

    Related cases: DESMOND-7831, DESMOND-6582.

    :param launcher_path: Path to the output files of the fep launcher stage.
    :param g: Graph object to use for checking the output files.
    :param membrane: Set to True if this is a membrane fep run.
                     Default is False.
    :param deltaE_threshold: If not None, maximum absolute value for the deltaE energy in
                             kcal/mol.
    :param expected_lambdas: Mapping of expected lambda schedules, passed
                             through to `_check_fep_in_cfg`.
                             If not given, use the default for each type.
    :raises AssertionError: if any of the checks fails.
    """
    from schrodinger.application.scisol.packages.fep.graph import Legtype

    def is_cms(fname) -> bool:
        return fname.endswith(('.cms', '.cmsgz', '.cms.gz'))

    job_name = '_'.join(os.path.basename(launcher_path).split('_')[:-1])
    fep_type = g.fep_type
    for e in g.edges_iter():
        from_id, to_id = e.short_id
        legs = [Legtype.COMPLEX.value, Legtype.SOLVENT.value]
        if e.is_fragment_linking:
            legs += list(constants.FRAGMENT_LINKING_HYDRATION_JOBS)
        for leg in legs:
            subjob_name = f'{job_name}_{from_id}_{to_id}_{leg}'
            # Check that the job completed
            # NOTE(review): the return value is ignored; this only checks
            # anything if parse_log_file raises on failure - confirm.
            multisim_log_fname = f'{os.path.join(launcher_path, subjob_name)}_multisim.log'
            standard_workups.parse_log_file(multisim_log_fname,
                                            'Multisim completed')
            # Reset for each subjob so that a log without a lambda_hopping
            # stage cannot silently reuse the previous subjob's index.
            lambda_hopping_idx = None
            with open(multisim_log_fname) as f:
                for line in f:
                    stage_match = re.match(
                        r'.*?stage (\d+) - lambda_hopping.*?', line)
                    if stage_match:
                        lambda_hopping_idx = int(stage_match.group(1))
                        break
            assert lambda_hopping_idx is not None, \
                f'No lambda_hopping stage found in {multisim_log_fname}.'
            # Check the lambda hopping out.tgz
            lambda_hopping_tgz = os.path.join(
                launcher_path, f'{subjob_name}_{lambda_hopping_idx}-out.tgz')
            # /some/path/<JOBNAME>_8-out.tgz -> <JOBNAME>_8
            lambda_hopping_stage = os.path.basename(
                lambda_hopping_tgz.replace('-out.tgz', ''))
            with tarfile.open(lambda_hopping_tgz) as tar:
                tar_path = os.path.join(f'{subjob_name}_fep1',
                                        lambda_hopping_stage)
                if leg in constants.FRAGMENT_LINKING_HYDRATION_JOBS:
                    tar_path = lambda_hopping_stage
                basename = os.path.join(tar_path, lambda_hopping_stage)
                log_contents = tar.extractfile(basename +
                                               '.log').read().decode('latin-1')
                _check_fep_backend_log(log_contents)
                # Check the -in.cfg file
                # DESMOND-8811
                cfg_contents = tar.extractfile(
                    f'{basename}-in.cfg').read().decode('latin-1')
                _check_fep_in_cfg(cfg_contents, e, leg, expected_lambdas)
                # Check the deltaE.txt file
                # DESMOND-9445
                if deltaE_threshold is not None:
                    deltaE = tar.extractfile(
                        f'{tar_path}/deltaE.txt').read().decode('latin-1')
                    _check_deltaE(deltaE, deltaE_threshold)
                # Check cms files
                cms_members = [
                    member for member in tar.getmembers()
                    if is_cms(member.name)
                ]
                assert cms_members, \
                    f'No cms files found in {lambda_hopping_tgz}.'
                for member in cms_members:
                    cms_fname = member.name
                    cms_contents_fobj = tar.extractfile(cms_fname)
                    if cms_fname.endswith('gz'):
                        cms_contents_fobj = gzip.GzipFile(
                            mode='r', fileobj=cms_contents_fobj)
                    replica_cts = list(
                        structure.StructureReader.fromString(
                            cms_contents_fobj.read().decode('latin-1')))
                    _check_fep_cms(cms_fname,
                                   replica_cts,
                                   leg,
                                   fep_type,
                                   membrane=membrane)
def _check_edge_values(edge: "graph.Edge", values: Dict[str, object]):
    """
    Check values associated with the graph edge.
    :param edge: Edge to check
    :param values: Dictionary with the keys 'simulation_protocol',
                   'ddg', 'ddg_err'. The values in the edge are compared
                   with the corresponding values here.
    """
    # Check the simulation protocol
    if values.get('simulation_protocol'):
        assert edge.simulation_protocol == values['simulation_protocol'], \
            f"Edge {edge.short_id} has incorrect simulation protocol. " \
            f"Found {edge.simulation_protocol}, expected {values['simulation_protocol']}"
    # Check free energy values.
    if values.get('ddg'):
        assert pytest.approx(edge.bennett_ddg.val, abs=0.5) == values['ddg'], \
            f'{edge.short_id} bennett_ddg.val different. ' \
            f'Found {edge.bennett_ddg.val}, expected {values["ddg"]}.'
        assert pytest.approx(edge.bennett_ddg.unc, abs=0.5) == values['ddg_err'], \
            f'{edge.short_id} bennett_ddg.unc different. ' \
            f'Found {edge.bennett_ddg.unc}, expected {values["ddg_err"]}.'
    assert edge.bennett_ddg is not None, f'{edge.short_id} bennett_ddg is None'
    assert edge.ccc_ddg is not None, f'{edge.short_id} ccc_ddg is None'
def desmond_check_fep_results(
        fmp_fname: str,
        launcher_path: str,
        fep_type: constants.FEP_TYPES,
        expected: Dict[Tuple[str, str], Dict[str, object]],
        membrane: bool = False,
        min_frames: int = 21,
        skip_parched: bool = False,
        skip_sid: bool = False,
        edges_to_skip_parched_check: Optional[List[Tuple[str, str]]] = None,
        deltaE_threshold: Optional[float] = None,
        expected_lambdas=_EXPECTED_LAMBDAS):
    """
    Check the fep results given the fep type and `expected`,
    a dictionary mapping expected edge short ids to their
    corresponding attributes.

    :param fmp_fname: Name of the output fmp to check.
    :param launcher_path: Path containing the output of the FepLauncher stage.
    :param fep_type: Type of fep used to run the job.
    :param expected: Dictionary mapping edge short ids to
                     a dictionary of attributes to check.
                     The attributes are the simulation_protocol,
                     with the values `constants.SIMULATION_PROTOCOL`,
                     'ddg', 'ddg_err' with the values corresponding to
                     the free energy. If the attributes are not present,
                     the corresponding checks are skipped.
    :param membrane: Set to True if this is a membrane fep run.
                     Default is False.
    :param min_frames: Minimum number of frames in the parched trajectories.
                       Default is 21 for the number of frames in a
                       standard 5 ns FEP job.
    :param skip_parched: If True, skip checking the parched trajectories.
                         Default is False.
    :param skip_sid: If True, skip checking for sid analysis. Default is False.
    :param edges_to_skip_parched_check: A list of short edge ids for which to
                                        skip checking for parched trajectory
                                        when skip_parched is False. If
                                        skip_parched is True then all edges
                                        will be skipped.
    :param deltaE_threshold: If not None, maximum absolute value for the deltaE energy in
                             kcal/mol.
    :param expected_lambdas: Mapping from simulation protocol to the
                             expected lambda schedule string, whose part
                             after the last ':' is the number of replicas.
                             If not given, use the default for each type.
    :return: True if all checks pass.
    :raises AssertionError: if any of the checks fails.
    """
    from schrodinger.application.scisol.packages.fep.graph import Graph
    g = Graph.deserialize(fmp_fname)
    id2edge = {e.short_id: e for e in g.edges_iter()}
    edge_ids = set(expected.keys())
    # Every node referenced by an expected edge must be in the graph.
    node_ids = {a[0] for a in edge_ids}.union({a[1] for a in edge_ids})
    missing_in_graph = edge_ids - set(id2edge.keys())
    additional_in_graph = set(id2edge.keys()) - edge_ids
    assert len(missing_in_graph) == 0, \
        f"Missing edges in graph: {missing_in_graph}"
    assert len(additional_in_graph) == 0, \
        f"Additional edges in graph: {additional_in_graph}"
    edges_to_skip_parched_check = edges_to_skip_parched_check or []
    # Check edges.
    for edge_id, values in expected.items():
        e = id2edge.get(edge_id)
        # Check that edge is in the graph.
        assert e is not None, f"{edge_id} not in {fmp_fname}."
        _check_edge_values(e, values)
        if not skip_sid:
            # The same replica count is expected for both legs.
            expected_num_replicas = int(
                expected_lambdas[e.simulation_protocol].split(':')[-1])
            _check_edge_sid(e, expected_num_replicas, expected_num_replicas)
        if not skip_parched and edge_id not in edges_to_skip_parched_check:
            _check_edge_parched(e, min_frames)
    _check_graph_environment(g, membrane)
    _check_graph_nodes(g, node_ids)
    _check_fep_launcher(launcher_path, g, membrane, deltaE_threshold,
                        expected_lambdas)
    return True
def desmond_check_ab_fep_results(
    fmp_fname: str,
    expected_values: Dict[Tuple[str, str], Dict[str, object]],
    min_frames: int,
    expected_lambdas_complex: int,
    expected_lambdas_solvent: int,
    membrane: bool = False,
):
    """
    Check the results in an output fmp against `expected_values`, a
    dictionary mapping expected edge short ids to their corresponding
    attributes. The edge values, sid reports, parched trajectories, graph
    environment and graph nodes are checked. (The 'ab' in the name
    presumably stands for absolute binding - confirm.)

    :param fmp_fname: Name of the output fmp to check.
    :param expected_values: Dictionary mapping edge short ids to a
                            dictionary of attributes to check, as accepted
                            by `_check_edge_values`.
    :param min_frames: Minimum number of frames in the parched trajectories.
    :param expected_lambdas_complex: Expected number of replicas in the
                                     complex leg sid.
    :param expected_lambdas_solvent: Expected number of replicas in the
                                     solvent leg sid.
    :param membrane: Set to True if this is a membrane fep run.
                     Default is False.
    :return: True if all checks pass.
    :raises AssertionError: if any of the checks fails.
    """
    from schrodinger.application.scisol.packages.fep.graph import Graph
    g = Graph.deserialize(fmp_fname)
    id2edge = {e.short_id: e for e in g.edges_iter()}
    edge_ids = set(expected_values.keys())
    # Every node referenced by an expected edge must be in the graph.
    node_ids = {a[0] for a in edge_ids}.union({a[1] for a in edge_ids})
    missing_in_graph = edge_ids - set(id2edge.keys())
    additional_in_graph = set(id2edge.keys()) - edge_ids
    assert len(missing_in_graph) == 0, \
        f"Missing edges in graph: {missing_in_graph}"
    assert len(additional_in_graph) == 0, \
        f"Additional edges in graph: {additional_in_graph}"
    # Check edges.
    for edge_id, values in expected_values.items():
        e = id2edge.get(edge_id)
        # Check that edge is in the graph.
        assert e is not None, f"{edge_id} not in {fmp_fname}."
        _check_edge_values(e, values)
        _check_edge_sid(e, expected_lambdas_complex, expected_lambdas_solvent)
        _check_edge_parched(e, min_frames)
    _check_graph_environment(g, membrane)
    _check_graph_nodes(g, node_ids)
    return True
def desmond_check_memory_usage(logfile: str,
                               cpu_limits: List[float],
                               gpu_limits: Optional[List[float]] = None):
    """
    Get statistics on memory usage printed to desmond logfile and compare
    to specified inputs.

    Lines of the form "using <N> kB of <TYPE> memory" are collected per
    memory type, and the mean and maximum for each checked type are printed.

    :param logfile: The logfile containing the memory usages
    :param cpu_limits: The maximum acceptable value for the mean and maximum CPU
    memory usage in kB, as a (mean, max) pair
    :param gpu_limits: The maximum acceptable value for the mean and maximum GPU
    memory usage in kB, as a (mean, max) pair [optional]
    :return: True if, for every checked memory type, both the mean and the
    maximum usage are strictly below their limits
    :raises ValueError: if the log has no usage statement for a checked type
    """
    types_to_limits = {'CPU': cpu_limits}
    if gpu_limits:
        types_to_limits['GPU'] = gpu_limits
    usage_pattern = re.compile(r'using (\d+) kB of (.*?) memory')
    memory_usages = defaultdict(list)
    with open(logfile) as f:
        for line in f:
            match = usage_pattern.search(line)
            if match is not None:
                memory_usages[match.group(2)].append(int(match.group(1)))

    def check_limits(usages, limits, memtype):
        if not usages:
            raise ValueError(
                f"No {memtype} memory usage statements found in log "
                f"file")
        usages_arr = numpy.array(usages)
        mean_usage = usages_arr.mean()
        max_usage = usages_arr.max()
        mean_limit, max_limit = limits
        print(f"{memtype} usage mean: {mean_usage}, max {max_usage}")
        return mean_usage < mean_limit and max_usage < max_limit

    # Build the full list (no short-circuit) so that the statistics of every
    # memory type are printed even when an earlier type is over its limit.
    return all([
        check_limits(memory_usages[memtype], limits, memtype)
        for memtype, limits in types_to_limits.items()
    ])
def custom_charge_ct_count(input_mae_fname: str) -> int:
    """
    Return the count of cts that have a custom charge block.

    The count is the number of lines in the file that contain the text
    'mmffld_custom_prop' (presumably one such line per ct with a custom
    charge block - confirm against the mae format).

    :param input_mae_fname: Name of the mae file to scan.
    :return: Number of lines containing 'mmffld_custom_prop'.
    """
    with open(input_mae_fname, 'r') as f:
        return sum(1 for line in f if 'mmffld_custom_prop' in line)
if __name__ == "__main__":

    def extract_summary(string):
        """
        Return the first paragraph of `string` joined onto a single line.

        Leading blank lines are skipped and the paragraph ends at the first
        blank line after it, or at the end of the string. An empty or
        missing (None) string gives an empty summary.
        """
        if not string:
            return ""
        summary_lines = []
        for line in string.split("\n"):
            line = line.strip()
            if line:
                summary_lines.append(line)
            elif summary_lines:
                break
        return " ".join(summary_lines)

    # At module scope dir() lists exactly the keys of globals(), so looking
    # names up here is equivalent to the membership test plus eval() on the
    # command-line argument, without evaluating user-supplied text.
    namespace = globals()

    if '-h' in sys.argv and len(sys.argv) > 2:
        for arg in sys.argv[2:]:
            if arg in namespace:
                print(arg + ":")
                print(namespace[arg].__doc__)
    elif len(sys.argv) > 2 and sys.argv[1] in namespace:
        # NOTE(review): the remaining arguments are passed as one list of
        # strings, not unpacked - confirm this is what the workups expect.
        if namespace[sys.argv[1]](sys.argv[2:]):
            print("  Success!")
        else:
            print("  Failure.")
    else:
        print("usage: %s <workup> <args>\n" % os.path.split(__file__)[-1])
        print(extract_summary(__doc__), "\n")
        print("available workups:")
        # Iterate over a sorted snapshot: the loop itself binds new
        # module-level names.
        for workup_name in sorted(namespace):
            if workup_name.startswith("desmond"):
                method = namespace[workup_name]
                if not callable(method):
                    continue
                summary = extract_summary(method.__doc__)
                print(f"  {method.__name__}: {summary}")