"""
Utilities for order parameter analysis.
Copyright Schrodinger, LLC. All rights reserved.
"""
import copy
import json
import os
import warnings
from collections import namedtuple
from collections import defaultdict
import numpy
from schrodinger.application.desmond import cms
from schrodinger.application.desmond.packages import analysis
from schrodinger.application.desmond.packages import staf
from schrodinger.application.desmond.packages import topo
from schrodinger.application.desmond.packages import traj
from schrodinger.application.matsci import jobutils
from schrodinger.application.matsci import parserutils
from schrodinger.application.matsci import msprops
from schrodinger.application.matsci import textlogger
from schrodinger.structutils import analyze
from schrodinger.structutils import transform
from schrodinger.utils import fileutils
# see SHARED-4320 which is an export to excel formatting issue
with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    import pandas
# CT property on the output .cms that lists the order parameter result files
ORDER_PARAMETER_OUTPUT_FILES_PROP = 's_matsci_order_parameter_output_files'
# separator used to join the result file names stored in the property above
ORDER_PARAMETER_OUTPUT_FILES_SEP = ','
# extension of the order parameter result files
RESULTS_FILE_EXT = 'csv'
# tag used in the names of the total order parameter files and the output .cms
OUT_FILE_TAG = 'op'
# tag used in the names of the per-molecule order parameter files
PER_MOL_OUT_FILE_TAG = f'per_mol_{OUT_FILE_TAG}'
# separator between a descriptor name and a molecule number in column names
MOL_TAG = '_mol_'
# name of the frame index column written to the result files
INDEX_NAME = 'Frame'
# total_op: the total order parameter values for one descriptor (one row per
# frame in the result files); per_mol_op: list of PerMolOP, one appended per
# reduce_vec call
Result = namedtuple('Result', ['total_op', 'per_mol_op'])
# molecule numbers and the matching per-molecule order parameter values
# (each scaled by 1/N, N being the number of description vectors) recorded by
# a single reduce_vec call
PerMolOP = namedtuple('PerMolOP', ['mol_numbers', 'mol_ops'])
def log(msg, **kwargs):
    """
    Send a message to the log through textlogger.log_msg.

    :param str msg: The message to log

    Any other keyword arguments (for example timestamp, pad and logger, as
    used by OrderParameter.run) are forwarded untouched to
    textlogger.log_msg.
    """
    log_msg = textlogger.log_msg
    log_msg(msg, **kwargs)
def get_name_w_mol_n(name, mol_n):
    """
    Build a name tagged with a molecule number, i.e. name + MOL_TAG + mol_n.

    :type name: str
    :param name: the name
    :type mol_n: int
    :param mol_n: the molecule number
    :rtype: str
    :return: the name with the molecule number appended
    """
    return '{}{}{}'.format(name, MOL_TAG, mol_n)
def get_name_wo_mol_n(name):
    """
    Return the name without the molecule number.

    This is the inverse of get_name_w_mol_n: only a trailing
    ``<MOL_TAG><integer>`` suffix is stripped, so base names that themselves
    contain MOL_TAG are handled, e.g. 'a_mol_b_mol_3' -> ('a_mol_b', 3).
    If MOL_TAG is not followed by an integer, it is treated as part of the
    name.

    :type name: str
    :param name: the name
    :rtype: (str, int) or (str, None)
    :return: the name without the molecule number and the molecule number
        if there is one, otherwise the unchanged name and None
    """
    # split on the LAST occurrence: get_name_w_mol_n appends the tag and
    # number at the end of the name
    base, sep, suffix = name.rpartition(MOL_TAG)
    if sep:
        try:
            return base, int(suffix)
        except ValueError:
            # the tag is not followed by a molecule number, so it belongs to
            # the name itself
            pass
    return name, None
class ReduceVecMixin:
    """
    Manages reducing vectors.

    Besides returning the total order parameter, reduce_vec appends a
    PerMolOP to ``self.per_mol_op`` on every call. The host class must
    provide ``_aids`` and ``_cms_model``. It may also provide
    ``_mol_numbers``, holding one molecule number per row of the
    description vectors, in row order.
    """

    def reduce_vec(self, n, m):
        """
        Specify how to reduce the reference director vector, n, and all of the
        description vectors, m, into an order parameter.

        Each per-molecule value recorded in ``self.per_mol_op`` is that row's
        order parameter divided by the number of rows. If analysis.reduce_vec
        averages over rows, as the equation in OrderParameter implies, these
        values sum to the returned total.

        :type n: numpy.array
        :param n: the reference director vector (row vector of length 3);
            the caller's array is not modified
        :type m: numpy.array
        :param m: the description vectors (matrix of size N X 3), where N is
            the number of molecules
        :rtype: float
        :return: the order parameter
        """
        # need to shape the incoming array; reshape a local view instead of
        # assigning n.shape, which would mutate the caller's director array
        n = numpy.reshape(n, (1, 3))
        mol_ops = [
            analysis.reduce_vec(n, numpy.array([mol_m])) / len(m) for mol_m in m
        ]
        # molecule numbers can't be cached here because they are dynamic
        director_mol_numbers = getattr(self, '_mol_numbers', None)
        if director_mol_numbers:
            # the SMARTS directors append molecule numbers in the same order
            # as their analyzers (presumably the row order of m -- confirm
            # against staf/analysis), and that order is not necessarily
            # sorted; sort numbers and values together so they stay paired
            pairs = sorted(zip(director_mol_numbers, mol_ops),
                           key=lambda pair: pair[0])
            mol_numbers = [mol_number for mol_number, _ in pairs]
            mol_ops = [mol_op for _, mol_op in pairs]
        else:
            mol_numbers = sorted({
                self._cms_model.atom[aid].getMolecule().number
                for aid in self._aids
            })
        self.per_mol_op = getattr(self, 'per_mol_op', [])
        self.per_mol_op.append(
            PerMolOP(mol_numbers=mol_numbers, mol_ops=mol_ops))
        return analysis.reduce_vec(n, m)
def get_trj_from_cms_file(cms_file):
    """
    Load the trajectory that belongs to a .cms file.

    The trajectory directory name is read from the full-system CT of the
    .cms file and resolved relative to the directory holding the .cms file.

    :type cms_file: str
    :param cms_file: the .cms file (can be a path)
    :rtype: list
    :return: contains the trajectory as a list of
        schrodinger.application.desmond.packages.traj.Frame
    """
    cms_dir = os.path.dirname(cms_file)
    fsys_ct = cms.Cms(file=cms_file).fsys_ct
    trj_name = parserutils.get_trj_dir_name(fsys_ct)
    return traj.read_traj(os.path.join(cms_dir, trj_name))
class AslException(Exception):
    """
    Raised when a descriptor's ASL does not match any atoms.
    """
class SmartsException(Exception):
    """
    Raised when a descriptor's SMARTS cannot be used. This happens when it
    matches no atoms, has the wrong number of atoms per match, has no unique
    match, or has no matches within the ASL.
    """
class DipoleDirector(analysis.DipoleDirector, ReduceVecMixin):
    """
    Dipole director that validates its ASL and records per-molecule order
    parameters through ReduceVecMixin.
    """

    def _dyninit(self):
        """
        Hook called after the atomic ids are updated. It lets geometry
        calculations be redefined per frame when the ASL is geometry
        dependent, i.e. contains a within or beyond argument.

        :raise AslException: if the ASL selects no atoms
        """
        selected_atoms = set(self._aids)
        if not selected_atoms:
            raise AslException(
                'The given ASL of {asl} does not match any atoms.'.format(
                    asl=self._asl))
        super()._dyninit()
class MomentOfInertiaDirector(analysis.MomentOfInertiaDirector, ReduceVecMixin):
    """
    Moment of inertia director that validates its ASL and records
    per-molecule order parameters through ReduceVecMixin.
    """

    def _dyninit(self):
        """
        Hook called after the atomic ids are updated. It lets geometry
        calculations be redefined per frame when the ASL is geometry
        dependent, i.e. contains a within or beyond argument.

        :raise AslException: if the ASL selects no atoms
        """
        template = 'The given ASL of {asl} does not match any atoms.'
        if len(set(self._aids)) == 0:
            raise AslException(template.format(asl=self._asl))
        super()._dyninit()
class UniqueSmartsDirector(analysis.SmartsDirector, ReduceVecMixin):
    """
    Manage a unique SMARTS director.

    Only molecules with exactly one SMARTS match contribute, and only when
    every matched atom is selected by the ASL. A 2-atom match gives the
    vector from the first to the second atom. A 3-atom match gives the unit
    normal of the plane spanned by the vectors from the first atom to the
    other two.
    """

    def __init__(self, msys_model, cms_model, asl, smarts):
        """
        Create an instance.

        :type msys_model: Desmond msys .System
        :param msys_model: the msys object (from msys.LoadMAE)
        :type cms_model: schrodinger.application.desmond.cms.Cms
        :param cms_model: the cms object
        :type asl: str
        :param asl: the ASL for the descriptor
        :type smarts: str
        :param smarts: the SMARTS for the descriptor
        """
        self._smarts = smarts
        # use Canvas SMARTS even though it is slower than mmpatty SMARTS,
        # because it supports component-level SMARTS, i.e. containing '.',
        # as here we need to support matching nonbonded pairs within the same
        # molecule
        #
        # the result is a dict whose values are per-molecule lists of matches
        # (matches_by_mol=True); its ordering is completely arbitrary
        self._atom_pairs = analyze.evaluate_smarts_by_molecule(
            cms_model.fsys_ct, smarts, uniqueFilter=True, matches_by_mol=True)
        # initialize the composite analyzer directly rather than through
        # analysis.SmartsDirector.__init__
        staf.CompositeDynamicAslAnalyzer.__init__(self, msys_model, cms_model,
                                                  asl)

    def _dyninit(self):
        """
        Hook called after the atomic ids are updated. It lets geometry
        calculations be redefined per frame when the ASL is geometry
        dependent, i.e. contains a within or beyond argument.

        Note that in some applications the actual direction of the descriptor
        matters little, since individual order parameter values are
        meaningless and only relative ones have meaning. In other applications
        individual order parameters are expected to have meaning: 1.0 pointing
        along the reference director, 0.0 orthogonal to it, and -1.0 anti to
        it. Requiring a unique match per molecule ensures this.

        :raise AslException: if the ASL selects no atoms
        :raise SmartsException: if the SMARTS matches nothing, is not for
            pairs or triples of atoms, has no unique match in any molecule,
            or none of its unique matches lie within the ASL
        """
        if not set(self._aids):
            raise AslException(
                'The given ASL of {asl} does not match any atoms.'.format(
                    asl=self._asl))
        self._analyzers = []
        self._mol_numbers = []
        selected = set(self._aids)
        fsys_ct = self._cms_model.fsys_ct
        if not self._atom_pairs:
            raise SmartsException(
                'The given SMARTS of {smarts} does not match any atoms.'.format(
                    smarts=self._smarts))
        first_mol_matches = next(iter(self._atom_pairs.values()))
        self._match_len = len(first_mol_matches[0])
        if self._match_len not in (2, 3):
            raise SmartsException(
                'The given SMARTS of {smarts} is not for pairs or triples of atoms.'
                .format(smarts=self._smarts))
        unique_match_found = False
        for matches in self._atom_pairs.values():
            # in order to unambiguously reduce the description vector only a
            # single match per molecule is allowed
            if len(matches) != 1:
                continue
            unique_match_found = True
            (match,) = matches
            if not selected.issuperset(match):
                continue
            root = match[0]
            # one vector per non-root atom: root->second (and root->third)
            self._analyzers.extend([
                analysis.Vector(self._msys_model, self._cms_model, root,
                                match[idx])
                for idx in range(1, self._match_len)
            ])
            self._mol_numbers.append(fsys_ct.atom[root].getMolecule().number)
        if self._analyzers:
            return
        if unique_match_found:
            raise SmartsException(
                ('The given SMARTS of {smarts} does not overlap with the '
                 'given ASL of {asl}.').format(smarts=self._smarts,
                                              asl=self._asl))
        raise SmartsException(
            ('The given SMARTS of {smarts} does not uniquely match any of '
             'the molecules.').format(smarts=self._smarts))

    def reduce_vec(self, n, m):
        """
        Specify how to reduce the reference director vector, n, and all of the
        description vectors, m, into an order parameter.

        For triples, consecutive pairs of rows of m (root->second,
        root->third, presumably in self._analyzers order) are first replaced
        by the unit normal of their plane.

        :type n: numpy.array
        :param n: the reference director vector (row vector of length 3)
        :type m: numpy.array
        :param m: the description vectors (matrix of size N X 3), where N is
            the number of molecules, or 2N for triples
        :rtype: float
        :return: the order parameter
        """
        if self._match_len == 3:
            m = numpy.array([
                transform.get_normalized_vector(
                    numpy.cross(m[idx], m[idx + 1]))
                for idx in range(0, len(m), 2)
            ])
        return ReduceVecMixin.reduce_vec(self, n, m)
class SmartsDirector(analysis.SmartsDirector, ReduceVecMixin):
    """
    Manage a SMARTS director.

    Every matched bonding pair whose two atoms are both selected by the ASL
    contributes one description vector (first atom to second atom). A
    molecule may therefore contribute several vectors.
    """

    def _dyninit(self):
        """
        A method that is automatically called after the atomic ids are
        updated and allows for dynamically redefining geometry
        calculations (per frame) for cases where the ASL is geometry dependent,
        i.e. contains a within or beyond argument.

        :raise AslException: if there is an issue with the ASL
        :raise SmartsException: if the SMARTS matches no atoms, is not for
            pairs of atoms, or none of its pairs lie within the ASL
        """
        if not set(self._aids):
            msg = ('The given ASL of {asl} does not match any atoms.')
            raise AslException(msg.format(asl=self._asl))
        self._analyzers = []
        self._mol_numbers = []
        aids = set(self._aids)
        fsys_ct = self._cms_model.fsys_ct
        if not self._atom_pairs:
            msg = ('The given SMARTS of {smarts} does not match any atoms.')
            raise SmartsException(msg.format(smarts=self._smarts))
        match_len = len(self._atom_pairs[0])
        if match_len != 2:
            msg = ('The given SMARTS of {smarts} is not for bonding pairs '
                   'of atoms.')
            raise SmartsException(msg.format(smarts=self._smarts))
        for a1, a2 in self._atom_pairs:
            if a1 in aids and a2 in aids:
                self._analyzers.append(
                    analysis.Vector(self._msys_model, self._cms_model, a1, a2))
                self._mol_numbers.append(fsys_ct.atom[a1].getMolecule().number)
        if not self._analyzers:
            # without any description vectors there is nothing to reduce into
            # an order parameter, and the empty molecule number list breaks the
            # per-molecule bookkeeping; fail clearly instead, mirroring the
            # overlap check in UniqueSmartsDirector._dyninit
            msg = ('The given SMARTS of {smarts} does not overlap with the '
                   'given ASL of {asl}.')
            raise SmartsException(
                msg.format(smarts=self._smarts, asl=self._asl))
class Descriptor:
    """
    Manage a descriptor.

    A descriptor names a way of computing description vectors (dipole,
    moment of inertia or SMARTS based) for the atoms selected by an ASL.
    Descriptors sharing a group are written to the same result files.
    """
    DIPOLE = 'dipole'
    MOMENT_OF_INERTIA = 'moment_of_inertia'
    SMARTS_NONUNIQUE_BONDS = 'SMARTS_nonunique_bonds'
    SMARTS_UNIQUE_PAIR = 'SMARTS_unique_pair'
    SMARTS_UNIQUE_TRIPLE_NORMAL = 'SMARTS_unique_triple_normal'
    TYPES_TO_CLASSES = {
        DIPOLE: DipoleDirector,
        MOMENT_OF_INERTIA: MomentOfInertiaDirector,
        SMARTS_NONUNIQUE_BONDS: SmartsDirector,
        SMARTS_UNIQUE_PAIR: UniqueSmartsDirector,
        SMARTS_UNIQUE_TRIPLE_NORMAL: UniqueSmartsDirector
    }
    # descriptor types that require a SMARTS, mapped to the number of atoms
    # per SMARTS match; getDescriptor only uses the keys
    _SMARTS_MATCH_DICT = {
        SMARTS_NONUNIQUE_BONDS: 2,
        SMARTS_UNIQUE_PAIR: 2,
        SMARTS_UNIQUE_TRIPLE_NORMAL: 3
    }
    # keyword argument names accepted by __init__ besides the name; these are
    # the keys of an entry in a descriptors .json file
    GROUP_KWARG = 'group'
    ASL_KWARG = 'asl'
    ATYPE_KWARG = 'atype'
    SMARTS_KWARG = 'smarts'

    def __init__(self, name, group=None, asl=None, atype=None, smarts=None):
        """
        Create an instance.

        :type name: str
        :param name: the name of the descriptor
        :type group: str or None
        :param group: the group of the descriptor or
            None if there isn't one
        :type asl: str
        :param asl: the ASL for the descriptor or
            None if there isn't one
        :type atype: str
        :param atype: the type of descriptor or
            None if there isn't one
        :type smarts: str or None
        :param smarts: the SMARTS for the descriptor or
            None if there isn't one
        """
        self.name = name
        self.group = group
        self.asl = asl
        self.atype = atype
        self.smarts = smarts
        self.descriptor = None

    def getDescriptor(self, msys_obj, cms_obj):
        """
        Get the descriptor, also storing it in ``self.descriptor``.

        :type msys_obj: Desmond msys .System
        :param msys_obj: the msys object (from msys.LoadMAE)
        :type cms_obj: schrodinger.application.desmond.cms.Cms
        :param cms_obj: the cms object
        :rtype: schrodinger.application.desmond.packages.staf.CompositeDynamicAslAnalyzer
        :return: descriptor subclasses of the given type
        :raise KeyError: if the descriptor type is not in TYPES_TO_CLASSES
        """
        aclass = self.TYPES_TO_CLASSES[self.atype]
        args = (msys_obj, cms_obj, self.asl)
        if self.atype in self._SMARTS_MATCH_DICT:
            args += (self.smarts,)
        self.descriptor = aclass(*args)
        return self.descriptor

    def _joinFileName(self, basename, tag):
        """
        Join the base name, file tag, group and results extension with '.',
        leaving the group out when it is None.
        """
        parts = [basename, tag]
        # group is documented as optional; joining None would raise TypeError
        if self.group is not None:
            parts.append(self.group)
        parts.append(RESULTS_FILE_EXT)
        return '.'.join(parts)

    def getFileName(self, basename):
        """
        Return a file name for this descriptor using the given base name.

        When the descriptor has no group, the group is left out of the name.

        :type basename: str
        :param basename: base name to use in naming the file
        :rtype: str
        :return: the descriptor file name
        """
        return self._joinFileName(basename, OUT_FILE_TAG)

    def getPerMolFileName(self, basename):
        """
        Return a per molecule file name for this descriptor using the given
        base name.

        When the descriptor has no group, the group is left out of the name.

        :type basename: str
        :param basename: base name to use in naming the file
        :rtype: str
        :return: the descriptor per molecule file name
        """
        return self._joinFileName(basename, PER_MOL_OUT_FILE_TAG)
def get_descriptors_from_file(descriptors_file):
    """
    Return a list of descriptors from the given
    descriptors file.

    :type descriptors_file: str
    :param descriptors_file: .json file containing specifications
        for descriptors, i.e. ways to determine vectors used in
        computing the order parameters with respect to the director,
        a specification includes information like name, group, ASL,
        type, and SMARTS (can be a path)
    :rtype: list
    :return: contains Descriptor, one per top-level key of the file
    """
    # descriptor files are formatted like
    #
    # {
    #     "some_name_1":{
    #         "group":"some_group",
    #         "asl":"some_ASL",
    #         "atype":"some_type",
    #         "smarts":"some_SMARTS"
    #     },
    #     "some_name_2": ...
    # }
    #
    # read bytes so that json detects the UTF-8/16/32 encoding (and a UTF-8
    # BOM) itself instead of relying on the platform's default text encoding
    with open(descriptors_file, 'rb') as afile:
        adict = json.load(afile)
    return [Descriptor(name=name, **kwargs) for name, kwargs in adict.items()]
class Director(staf.GeomAnalyzerBase):
    """
    A static reference director: a fixed vector exposed as this analyzer's
    result (stored in ``_result``).
    """

    def __init__(self, vec):
        """
        Store the director vector.

        :type vec: numpy.array
        :param vec: the director vector
        """
        director_vector = vec
        self._result = director_vector
class OrderParameter:
    """
    Manage order parameter analysis.

    run() performs the following steps:
      - reads the .cms file and its trajectory
      - builds the reference director and the descriptors
      - computes total and per-molecule order parameters
      - writes them to .csv files
      - writes an output .cms whose CT property lists those files
    """

    def __init__(self,
                 cms_file,
                 director_abc_coeffs,
                 descriptors_file,
                 logger=None):
        """
        Create an instance.

        :type cms_file: str
        :param cms_file: the .cms file for the simulation on which
            to run the order parameter analysis (can be a path)
        :type director_abc_coeffs: tuple
        :param director_abc_coeffs: coefficients of the static reference
            director vector in the lattice vector basis, for example (0, 0, 1)
            for the c-lattice vector or z-axis of a cubic cell
        :type descriptors_file: str
        :param descriptors_file: .json file containing specifications
            for descriptors, i.e. ways to determine vectors used in
            computing the order parameters with respect to the director,
            a specification includes information like name, group, ASL,
            type, and SMARTS (can be a path) (see get_descriptors_from_file
            for more information)
        :type logger: logging.Logger or None
        :param logger: output logger or None if there isn't one
        """
        self.cms_file = cms_file
        self.director_abc_coeffs = director_abc_coeffs
        self.descriptors_file = descriptors_file
        self.logger = logger
        self.basename = jobutils.get_jobname(
            fileutils.get_basename(self.cms_file))
        self.msys_obj = None
        self.cms_obj = None
        self.director = None
        self.trajectory = None
        self.descriptors = None
        self.results = None

    @staticmethod
    def getDirector(cms_obj, director_abc_coeffs):
        """
        Return the unit director vector object for the given cms.

        The director is the combination of the simulation box (chorus)
        vectors, in Angstrom, weighted by the given coefficients and then
        normalized.

        :type cms_obj: schrodinger.application.desmond.cms.Cms
        :param cms_obj: the cms object
        :type director_abc_coeffs: tuple
        :param director_abc_coeffs: coefficients of the static reference
            director vector in the lattice vector basis, for example (0, 0, 1)
            for the c-lattice vector or z-axis of a cubic cell
        :rtype: Director
        :return: the unit director vector object
        """
        chorus = cms.get_box(cms_obj.fsys_ct)
        step = 3
        # split the flat chorus into the a, b and c lattice vectors
        vecs = [
            numpy.array(chorus[i:i + step])
            for i in range(0, len(chorus), step)
        ]
        director = numpy.zeros(3)
        for coeff, vec in zip(director_abc_coeffs, vecs):
            director += coeff * vec
        return Director(transform.get_normalized_vector(director))

    def _getResults(self):
        """
        Return the order parameter results.

        :rtype: list[Result]
        :return: contains a Result for each descriptor
        """
        # order parameters are calculated using the following equation
        #
        # S_k = (1/N) sum_i=1^N [3(M dot m_i^k)^2 - 1]/2
        #
        # where:
        #   - k is the frame index
        #   - N is the number of descriptor vectors (typically the number of
        #     molecules if there is one descriptor per molecule)
        #   - M is the static reference director vector, for example the
        #     z-axis; it is a unit vector, see getDirector
        #   - m_i is a descriptor vector, for example a dipole, a head-to-tail
        #     vector, etc.
        #   - S is the order parameter
        #
        # S is dimensionless provided the m_i are normalized as well
        # (NOTE(review): presumably done in analysis.reduce_vec -- confirm)
        analyses = []
        for descriptor in self.descriptors:
            d = descriptor.getDescriptor(self.msys_obj, self.cms_obj)
            analyses.append(
                analysis.OrderParameter(self.director, d, d.reduce_vec))
        results = analysis.analyze(self.trajectory, *analyses)
        # analyze returns a bare result, not a list, for a single analyzer
        if len(analyses) == 1:
            results = [results]
        all_results = []
        for result, _analysis in zip(results, analyses):
            # presumably _vec2 is the descriptor object d passed above, whose
            # ReduceVecMixin.reduce_vec filled per_mol_op -- confirm
            all_results.append(
                Result(total_op=result, per_mol_op=_analysis._vec2.per_mol_op))
        return all_results

    @staticmethod
    def _getPerMolColumns(per_mol_ops):
        """
        Return the per-molecule order parameter columns of one descriptor,
        aligned by frame.

        Molecule numbers can change between frames (e.g. with a geometry
        dependent ASL), and a molecule can own several description vectors
        in one frame (non-unique SMARTS bonds). Values are therefore placed
        by frame index rather than appended.

        :type per_mol_ops: list[PerMolOP]
        :param per_mol_ops: one PerMolOP per reduce_vec call (presumably one
            per frame)
        :rtype: dict
        :return: keys are molecule numbers in order of first appearance.
            Values are lists with one entry per frame: numpy.nan when the
            molecule is absent from that frame, otherwise the sum of its
            values in that frame (each value is already scaled by 1/N, so the
            sum is the molecule's contribution to the total)
        """
        n_frames = len(per_mol_ops)
        frame_ops_by_mol = {}
        for frame_idx, per_mol_op in enumerate(per_mol_ops):
            for mol_number, mol_op in zip(per_mol_op.mol_numbers,
                                          per_mol_op.mol_ops):
                frame_ops = frame_ops_by_mol.setdefault(mol_number, {})
                if frame_idx in frame_ops:
                    frame_ops[frame_idx] = frame_ops[frame_idx] + mol_op
                else:
                    frame_ops[frame_idx] = mol_op
        return {
            mol_number:
            [frame_ops.get(idx, numpy.nan) for idx in range(n_frames)]
            for mol_number, frame_ops in frame_ops_by_mol.items()
        }

    def _writeResultFiles(self):
        """
        Write the result files.

        Each descriptor contributes:
          - one column, named after the descriptor, to its group's total
            order parameter file
          - one column per molecule (see get_name_w_mol_n) to its group's
            per-molecule file

        Rows are frames, numbered from 1.

        :rtype: list
        :return: contains the file names of the written files
        """
        # there are two types of files, those with total order parameters
        # per frame and those with molecular order parameters per frame,
        # each can have multiple types of order parameters, collect the
        # types of order parameters and their total or molecular values
        # per file name in order to prepare csv output files
        results_by_file = defaultdict(list)
        for descriptor, result in zip(self.descriptors, self.results):
            # total
            file_name = descriptor.getFileName(self.basename)
            results_by_file[file_name].append(
                (descriptor.name, result.total_op))
            # molecular
            file_name = descriptor.getPerMolFileName(self.basename)
            per_mol_columns = self._getPerMolColumns(result.per_mol_op)
            for mol_number, mol_ops in per_mol_columns.items():
                results_by_file[file_name].append(
                    (get_name_w_mol_n(descriptor.name, mol_number), mol_ops))
        for afile, names_results in results_by_file.items():
            names, columns = zip(*names_results)
            # transpose the columns into per-frame rows
            rows = list(zip(*columns))
            df = pandas.DataFrame.from_records(data=rows, columns=names)
            df.index += 1
            df.index.name = INDEX_NAME
            df.to_csv(afile)
            jobutils.add_outfile_to_backend(afile)
        return list(results_by_file)

    def _writeOutputCMS(self, files):
        """
        Write the output cms file, with the result file names stored in the
        ORDER_PARAMETER_OUTPUT_FILES_PROP CT property.

        :type files: list
        :param files: contains the file names of the order parameter
            result files
        """
        # NOTE(review): copy.copy is a shallow copy; confirm the property
        # edits below do not also modify self.cms_obj
        cms_obj = copy.copy(self.cms_obj)
        for prop in (msprops.ORIGINAL_CMS_PROP, msprops.TRAJECTORY_FILE_PROP):
            cms_obj.remove_cts_property(prop)
        files_str = ORDER_PARAMETER_OUTPUT_FILES_SEP.join(files)
        cms_obj.set_cts_property(ORDER_PARAMETER_OUTPUT_FILES_PROP, files_str)
        jobutils.set_source_path(cms_obj)
        cms_out_file = '.'.join([self.basename, OUT_FILE_TAG, 'cms'])
        jobutils.write_cms_with_wam(
            cms_obj,
            cms_out_file,
            wam_type=jobutils.WAM_TYPES.MS_ORDER_PARAMETER)
        jobutils.add_outfile_to_backend(cms_out_file, set_structure_output=True)

    def run(self):
        """
        Run the order parameter analysis.

        :raise AslException: if a descriptor ASL is unusable (also logged to
            self.logger when there is one)
        :raise SmartsException: if a descriptor SMARTS is unusable (also
            logged to self.logger when there is one)
        """
        self.msys_obj, self.cms_obj = topo.read_cms(self.cms_file)
        self.director = OrderParameter.getDirector(self.cms_obj,
                                                   self.director_abc_coeffs)
        self.trajectory = get_trj_from_cms_file(self.cms_file)
        self.descriptors = get_descriptors_from_file(self.descriptors_file)
        try:
            self.results = self._getResults()
        except (AslException, SmartsException) as err:
            if self.logger:
                self.logger.error(str(err))
            raise
        files = self._writeResultFiles()
        self._writeOutputCMS(files)
        log('All finished', timestamp=True, pad=True, logger=self.logger)