"""
Utilities for Boltzmann averaging properties.
Copyright Schrodinger, LLC. All rights reserved.
"""
from schrodinger.application.matsci import msutils
from schrodinger.application.matsci import reaction_workflow_utils as rxnwfu
def validate_energy_property_names(energy_property_names):
    """
    Validate the energy property names.

    A name is considered valid when
    `rxnwfu.ReactionWorkflowEnergyAnalysis.getKcalPerMolConversion` returns a
    truthy conversion for it; a falsy result means the name carries no
    recognized units.

    :type energy_property_names: list
    :param energy_property_names: the energy property names to validate

    :raise ValueError: for the first name that is missing units
    """
    for energy_property_name in energy_property_names:
        conversion = (rxnwfu.ReactionWorkflowEnergyAnalysis.
                      getKcalPerMolConversion(energy_property_name))
        if not conversion:
            raise ValueError(f'The energy property name {energy_property_name} '
                             'is missing units.')


def get_rxnwf_structures(conformer_dict, allow_sibling_groups=False):
    """
    Return structures prepared for a reaction workflow.

    Every conformer is copied (via ``copy()``) before any property is
    modified. Conformer groups are numbered from 1 in the iteration order of
    ``conformer_dict``, and that number is used to label the groups:

    - Reaction workflow input (a truthy ``rxnwfu.REACTION_WF_STRUCTURE_KEY``)
      keeps its properties unchanged when ``allow_sibling_groups`` is True.
      Otherwise its sibling group is redefined so that each sibling group
      holds a single conformer group, and any parent sibling groups are
      removed.
    - Any other input is temporarily made reaction workflow input in which
      each sibling group contains a single conformer group.

    :type conformer_dict: dict
    :param conformer_dict: binned conformers, keys are str that are conformer
        hashes, values are list[`schrodinger.structure.Structure`]
    :type allow_sibling_groups: bool
    :param allow_sibling_groups: whether to keep the existing sibling groups
        of reaction workflow input
    :rtype: list[`schrodinger.structure.Structure`]
    :return: structures prepared for a reaction workflow
    """
    sts = []
    for idx, conformers in enumerate(conformer_dict.values(), 1):
        # the same labels apply to every conformer in this bin
        sibling_group = f'sibling_group_{idx}'
        conformer_group = f'conformer_group_{idx}'
        for conformer in conformers:
            new_conformer = conformer.copy()
            props = new_conformer.property
            if props.get(rxnwfu.REACTION_WF_STRUCTURE_KEY):
                if not allow_sibling_groups:
                    props[rxnwfu.SIBLING_GROUP_KEY] = sibling_group
                    props.pop(rxnwfu.PARENT_SIBLING_GROUPS_KEY, None)
            else:
                props[rxnwfu.REACTION_WF_STRUCTURE_KEY] = True
                props[rxnwfu.SIBLING_GROUP_KEY] = sibling_group
                props[rxnwfu.CONFORMERS_GROUP_KEY] = conformer_group
            sts.append(new_conformer)
    return sts


def _get_energy_analysis_properties(ea, property_names, atomic, temps,
                                    only_lowest_energy):
    """
    Return the averaged properties for each of the given property keys.

    :type ea: `rxnwfu.ReactionWorkflowEnergyAnalysis`
    :param ea: the energy analysis to query
    :type property_names: iterable
    :param property_names: the structure or atom property keys to average
    :type atomic: bool
    :param atomic: whether the property keys are atom property keys
    :type temps: list or None
    :param temps: temperatures in K, passed through to ``ea.getProperties``
    :type only_lowest_energy: bool
    :param only_lowest_energy: passed through to ``ea.getProperties``
    :raise ValueError: if ``ea.getProperties`` raises a
        `rxnwfu.ReactionWorkflowException`
    :rtype: list[`reaction_workflow_utils.EnergyAnalysisProperty`]
    :return: the averaged properties of all keys, in key order
    """
    properties = []
    for property_name in property_names:
        try:
            key_properties = ea.getProperties(
                include_x_terms=False,
                only_lowest_energy=only_lowest_energy,
                property_key=property_name,
                atomic=atomic,
                temps=temps)
        except rxnwfu.ReactionWorkflowException as err:
            # re-raise as the documented ValueError, keeping the cause
            raise ValueError(str(err)) from err
        properties.extend(key_properties)
    return properties


def get_averaged_properties(sts,
                            energy_property_names,
                            temps=None,
                            allow_sibling_groups=False,
                            atomic=False,
                            only_lowest_energy=False):
    """
    Return averaged properties.

    Every float structure property common to all structures is averaged;
    when ``atomic`` is True, every float atom property common to all
    structures is averaged as well. The structures are first binned with
    `get_conformer_dict` and prepared with `get_rxnwf_structures`.

    :type sts: list[`schrodinger.structure.Structure`]
    :param sts: the structures whose properties will be averaged
    :type energy_property_names: list
    :param energy_property_names: the energy property names to use for
        averaging
    :type temps: list
    :param temps: temperatures in K, only used for temperature independent
        property keys
    :type allow_sibling_groups: bool
    :param allow_sibling_groups: whether to average over sibling groups
    :type atomic: bool
    :param atomic: whether to also average atomic properties
    :type only_lowest_energy: bool
    :param only_lowest_energy: Use only the lowest energy conformer rather than
        averaging over conformers
    :raise ValueError: if no structures are given, if the structures share no
        float (atom) properties, if an energy property is not common to all
        structures or is missing units, if sibling groups are combined with
        atomic properties, or if the energy analysis fails
    :rtype: list[`reaction_workflow_utils.EnergyAnalysisProperty`]
    :return: the averaged properties, structure properties first
    """
    if not sts:
        raise ValueError('No structures were given.')
    st_property_names = msutils.get_common_float_property_names(sts)
    if not st_property_names:
        raise ValueError('There are no common structure properties defined '
                         'among the given structures.')
    if not set(energy_property_names).issubset(st_property_names):
        raise ValueError('At least one of the given energy properties is not '
                         'defined for all of the given structures.')
    if allow_sibling_groups and atomic:
        raise ValueError('Atomic properties for sibling groups is not '
                         'supported.')
    validate_energy_property_names(energy_property_names)
    conformer_dict = get_conformer_dict(sts)
    sts = get_rxnwf_structures(conformer_dict,
                               allow_sibling_groups=allow_sibling_groups)
    ea = rxnwfu.ReactionWorkflowEnergyAnalysis(sts, energy_property_names)
    all_properties = _get_energy_analysis_properties(
        ea, st_property_names, False, temps, only_lowest_energy)
    if not atomic:
        return all_properties
    # atom property names are taken from the prepared (copied) structures
    atom_property_names = msutils.get_common_float_atom_property_names(sts)
    if not atom_property_names:
        raise ValueError('There are no common atom properties defined among '
                         'the given structures.')
    all_properties.extend(
        _get_energy_analysis_properties(ea, atom_property_names, True, temps,
                                        only_lowest_energy))
    return all_properties


def get_averaged_value(eprop, averaged_properties, prop=None):
    """
    Get the value of an averaged property for a given energy property.

    :param str eprop: the energy property key that was used for averaging
    :param list[rxnwfu.EnergyAnalysisProperty] averaged_properties: the
        averaged properties to search
    :type prop: str or None
    :param prop: the key of the averaged property; if None, ``eprop`` is used,
        i.e. the averaged value of the energy property itself is returned
    :rtype: float
    :return: the averaged value, i.e. the first ``ensemble`` entry of the first
        averaged property matching both ``eprop`` and ``prop``
    :raise AssertionError: If energy property (and property if present) is not
        found (shouldn't happen)
    """
    prop = eprop if prop is None else prop
    for bprop in averaged_properties:
        if bprop.energy_key == eprop and bprop.property_key == prop:
            return bprop.ensemble[0]
    # report both keys so that lookups with a custom prop are diagnosable
    raise AssertionError(f'{eprop} (property {prop}) not found in '
                         f'{averaged_properties}')


def get_representatives(sts,
                        energy_property_name,
                        temps=None,
                        atomic=False,
                        only_lowest_energy=False):
    """
    Return representative structures marked with average properties.

    For every averaged property, the first representative conformer gets the
    averaged value under the averaged property key, and the original property
    is removed where possible. When the input is not reaction workflow input
    (judged from the first structure), the reaction workflow keys that were
    added temporarily for the averaging are removed again.

    :type sts: list[`schrodinger.structure.Structure`]
    :param sts: the structures whose properties will be averaged
    :type energy_property_name: str
    :param energy_property_name: the energy property name to use for
        averaging
    :type temps: list
    :param temps: temperatures in K, only used for temperature independent
        property keys
    :type atomic: bool
    :param atomic: whether to also average atomic properties
    :type only_lowest_energy: bool
    :param only_lowest_energy: Use only the lowest energy conformer rather than
        averaging over conformers
    :raise ValueError: if no structures are given or if averaging fails
    :rtype: list[`schrodinger.structure.Structure`]
    :return: the representatives
    """
    if not sts:
        raise ValueError('No structures were given.')
    # track whether this is rxnwf input so temporary keys can be undone below
    is_rxnwf = sts[0].property.get(rxnwfu.REACTION_WF_STRUCTURE_KEY)
    properties = get_averaged_properties(sts, [energy_property_name],
                                         temps=temps,
                                         allow_sibling_groups=False,
                                         atomic=atomic,
                                         only_lowest_energy=only_lowest_energy)
    st_properties = {}
    atom_properties = {}
    for aproperty in properties:
        st = aproperty.representative_conformers[0]
        avg_property_key = aproperty.avg_property_key
        avg_property_value = aproperty.ensemble[0]
        if aproperty.atom_idx:
            # some atom properties like r_m_charge1 are core properties and
            # can not be removed
            try:
                st.atom[aproperty.atom_idx].property.pop(
                    aproperty.property_key, None)
            except ValueError:
                pass
            atom_properties.setdefault(st, {}).setdefault(
                aproperty.atom_idx, {})[avg_property_key] = avg_property_value
        else:
            st.property.pop(aproperty.property_key, None)
            st_properties.setdefault(st,
                                     {})[avg_property_key] = avg_property_value
    # keys that get_rxnwf_structures adds to non-rxnwf input
    temporary_keys = (rxnwfu.REACTION_WF_STRUCTURE_KEY,
                      rxnwfu.SIBLING_GROUP_KEY, rxnwfu.CONFORMERS_GROUP_KEY)
    representatives = []
    for st, sprops in st_properties.items():
        if not is_rxnwf:
            for key in temporary_keys:
                st.property.pop(key, None)
        st.property.update(sprops)
        for atom_idx, aprops in atom_properties.get(st, {}).items():
            st.atom[atom_idx].property.update(aprops)
        representatives.append(st)
    return representatives