"""
Classes related to Monte Carlo simulations
Copyright Schrodinger, LLC. All rights reserved.
"""
import contextlib
import logging
import math
import random
import time
from past.utils import old_div
import numpy
import scipy.constants as sciconst
import scipy.spatial as spatial
from schrodinger.application.matsci import amorphous
from schrodinger.infra import mm
from schrodinger.structutils import minimize
from schrodinger.structutils import transform
# Boltzmann constant on a per-mole basis in kcal/(mol*K): sciconst.k * sciconst.N_A
# is the value in J/(mol*K), and dividing by 4.184 * 1000 converts J to kcal.
# Multiplying by a temperature in Kelvin gives kT in kcal/mol.
K_KCAL_PER_MOL = sciconst.k * sciconst.N_A / (4.184 * 1000)
class ClashChecker(amorphous.BuilderWithClashDetection):
    """
    A builder used just for checking for clashes during Monte Carlo simulations

    Rings are located once and cached, because a Monte Carlo run changes
    coordinates but never bonding.
    """

    def checkForRings(self, struct):
        """
        Check for rings in struct and cache them

        :type struct: `schrodinger.structure.Structure`
        :param struct: The structure to check for rings
        """
        self.rings = list(struct.find_rings())
        self.has_rings = bool(self.rings)

    def findRings(self, struct):
        """
        Return the cache of found rings. The structure may change coordinates
        but never bonding during an MC run, so struct is only consulted when
        no cache exists yet (i.e. `checkForRings` has not been called). The
        argument is otherwise kept for API compatibility with the parent
        class.

        :type struct: `schrodinger.structure.Structure`
        :param struct: Only used to build the ring cache if it is missing

        :rtype: list
        :return: The list of found rings
        """
        # Build the cache on first use rather than raising AttributeError when
        # checkForRings was never called for this checker
        if getattr(self, 'rings', None) is None:
            self.checkForRings(struct)
        return self.rings
class MonteCarloMolecule(object):
    """
    Class for treatment of a moving molecule during a Monte Carlo iteration

    The molecule is extracted from the cell into its own structure, moved
    there, and the new coordinates are then written back to the cell atoms.
    """

    def __init__(self, cell, molnum, box):
        """
        Create a MonteCarloMolecule object

        :type cell: `schrodinger.structure.Structure`
        :param cell: The complete cell of molecules

        :type molnum: int
        :param molnum: The molecule number that will be moved

        :type box: `schrodinger.application.matsci.amorphous.Box`
        :param box: The Box object that contains information about the system
            size
        """
        self.molnum = molnum
        self.cell = cell
        self.molecule = cell.molecule[molnum]
        # Cell atom indexes of this molecule, used to write coordinates back
        self.atoms = self.molecule.getAtomIndices()
        self.molstruct = self.molecule.extractStructure()
        # Coordinates before any move is applied
        self.original_coords = self.getCoordinates()
        self.box = box

    def randomlyRotate(self, max_degrees):
        """
        Randomly rotate the molecule by no more than the given number of degrees

        :type max_degrees: float
        :param max_degrees: The largest rotation allowed
        """
        radians = numpy.radians(max_degrees)
        amorphous.random_rotation(self.molstruct, max_rotation=radians)
        self.updateCoordinates()

    def randomlyTranslate(self, mu_sigma):
        """
        Randomly translate the molecule. The distribution of translation
        distances is given by a lognormal distribution and the direction is
        chosen isotropically.

        If the translated centroid leaves the box, the molecule is shifted to
        its mirror image inside the box. If the resulting centroid is not a
        valid point of the box, the move is discarded and the molecule is put
        back where it started.

        :type mu_sigma: (float, float)
        :param mu_sigma: (mu, sigma). mu gives the mean value of the lognormal
            distribution the translation is taken from. sigma gives the standard
            deviation of the distribution
        """
        mu, sigma = mu_sigma
        # Independent Gaussian components give a direction that is uniform over
        # the whole sphere. Components drawn from random.random() are all
        # non-negative, which would restrict every move to a single octant.
        direction = numpy.array([random.gauss(0.0, 1.0) for _ in range(3)])
        normvec = transform.get_normalized_vector(direction)
        dist = random.lognormvariate(mu, sigma)
        tvec = dist * normvec
        transform.translate_structure(self.molstruct, *tvec)
        # Track the total displacement so a rejected move can be undone
        net_shift = numpy.array(tvec, dtype=float)
        centroid = transform.get_centroid(self.molstruct)[:3]
        if not self.box.isPointInBox(centroid):
            # Point falls outside the box, shift it to its mirror image inside
            # the box
            shifter = self.box.getTranslationToBox(centroid)
            transform.translate_structure(self.molstruct, *shifter)
            new_centroid = [a + b for a, b in zip(centroid, shifter)]
            net_shift = net_shift + numpy.asarray(shifter, dtype=float)
        else:
            new_centroid = centroid
        # Ensure the point falls in the valid region of the box
        if self.box.isValidPoint(new_centroid):
            self.updateCoordinates()
        else:
            # The cell is left untouched, so undo the move on the extracted
            # structure too; otherwise getCoordinates() would disagree with
            # the cell
            transform.translate_structure(self.molstruct, *(-net_shift))

    def updateCoordinates(self):
        """
        Update the coordinates of this molecule within the entire cell
        """
        coords = self.getCoordinates()
        for xyz, index in zip(coords, self.atoms):
            self.cell.atom[index].xyz = xyz

    def getCoordinates(self):
        """
        Get the xyz coordinate for this molecule

        :rtype: `numpy.array`
        :return: The xyz coordinates of this molecule
        """
        return self.molstruct.getXYZ()
class Metropolizer(object):
    """
    A class that runs a Monte Carlo simulation using the Metropolis algorithm
    """

    # Move type identifiers, also used as keys into self.limits
    ROTATE, TRANSLATE = range(2)

    def __init__(self,
                 scaffold,
                 cell=None,
                 weight_rotate=0.5,
                 weight_translate=0.5,
                 translate_mu=0.25,
                 translate_sigma=1.0,
                 max_rotate=360.0,
                 temperatures=None,
                 iterations=10000,
                 clash_penalty=50.0,
                 minimize_interval=None,
                 forcefield=minimize.OPLS_2005,
                 vdw_scale=1.0,
                 gravity=True,
                 gravity_weight=4.0,
                 logger=None,
                 cleanup=True):
        """
        Create a Metropolizer object

        :type scaffold: `amorphous.Scaffold`
        :param scaffold: The scaffold object that controls the cell structure

        :type cell: `schrodinger.structure.Structure`
        :param cell: The structure containing the molecules to move via Monte
            Carlo

        :type weight_rotate: float
        :param weight_rotate: The weight of rotation when randomly choosing to
            rotate or translate

        :type weight_translate: float
        :param weight_translate: The weight of translations when randomly
            choosing to rotate or translate

        :type translate_mu: float
        :param translate_mu: The mean of the natural logarithm function for the
            log-normal distribution of translation distances.

        :type translate_sigma: float
        :param translate_sigma: The standard deviation of the natural logarithm
            function for the log-normal distribution of translation distances.

        :type max_rotate: float
        :param max_rotate: The maximum number of degrees for any rotation

        :type temperatures: list of float
        :param temperatures: A list of temperatures to run the annealing at.
            A single temperature of 300.0 is used if none are given.

        :type iterations: int
        :param iterations: The number of Monte Carlo iterations to run at each
            temperature

        :type clash_penalty: float
        :param clash_penalty: Penalty for clashes

        :type minimize_interval: int
        :param minimize_interval: Do a minimization after every Xth interval.
            Not implemented at this time.

        :type forcefield: int or None
        :param forcefield: The mmffld number of the forcefield to use for energy
            evaluations. Use None to turn off forcefield energy evaluations.

        :type vdw_scale: float
        :param vdw_scale: The VdW scale factor to use for clash checking

        :type gravity: bool
        :param gravity: Whether to use the gravity term. Gravity attracts all
            molecules toward the scaffold if a scaffold molecule is present, or
            the center of the cell if no scaffold is present. If no forcefield
            term is included, then a simple hard shell model is used to prevent
            clashes caused by gravity.

        :type gravity_weight: float
        :param gravity_weight: The factor that converts a change in closest
            approach distance into a gravitational energy

        :type logger: `logging.Logger`
        :param logger: The logger for this class

        :type cleanup: bool
        :param cleanup: Attempt to clean up the Lewis structure before
            evaluating the energy. Only relevant if forcefield is not None.
        """
        self.cell = cell
        self.scaffold = scaffold
        self.box = scaffold.box
        weight_sum = weight_translate + weight_rotate
        # True division: integer weights such as 1 and 1 must give 0.5 rather
        # than being floor-divided down to 0
        self.rotate_chance = float(weight_rotate) / weight_sum
        # Per-move-type limit, passed to the matching MonteCarloMolecule method
        self.limits = {
            self.ROTATE: max_rotate,
            self.TRANSLATE: (translate_mu, translate_sigma),
        }
        self.temperatures = temperatures or [300.0]
        self.iterations = iterations
        self.clash_penalty = clash_penalty
        self.minimize_interval = minimize_interval
        self.forcefield = forcefield
        self.gravity = gravity
        self.gravity_weight = gravity_weight
        self.clash_vdw_scale = vdw_scale
        self.cleanup = cleanup
        self.logger = logger
        self.setupGravity()

    def setupGravity(self):
        """
        Pre-compute data for the gravity term
        """
        if not self.gravity:
            self.clasher = None
            self.ckd_tree = None
            self.gravity_cutoff = numpy.inf
            self.gravity_clashes = 0
        else:
            if self.scaffold:
                # The ckd_tree is used to find the closest approach between a
                # molecule and any scaffold atoms
                self.ckd_tree = spatial.cKDTree(self.scaffold.scaffold.getXYZ())
            else:
                self.ckd_tree = None
            if self.box:
                self.gravity_cutoff = self.box.getLargestSpan()
                self.gravity_center = self.box.getCentroid()
            else:
                self.gravity_cutoff = numpy.inf
                self.gravity_center = numpy.array([0., 0., 0.])
            self.clasher = ClashChecker(logger=self.logger)
            if self.cell:
                self.clasher.checkForRings(self.cell)

    def getNumClashes(self, struct):
        """
        Get the number of clashes for the proposed structure

        :type struct: `schrodinger.structure.Structure`
        :param struct: The structure to check for clashes

        :rtype: int
        :return: The total number of clashes found
        """
        clashes = self.clasher.getClashes(struct, self.clash_vdw_scale)
        return sum(len(x) for x in clashes.values())

    def getClosestApproach(self, coords):
        """
        Get the closest approach between the given set of coordinates and the
        scaffold molecule or gravity center if no scaffold.

        :type coords: `numpy.array`
        :param coords: The XYZ coordinates to check for close approach to the
            scaffold - such as from the getXYZ()

        :rtype: float
        :return: The closest approach between coords and the scaffold, or the
            gravity center if no scaffold was used.
        """
        if self.ckd_tree is not None:
            cut = self.gravity_cutoff
            distances, unused = self.ckd_tree.query(coords,
                                                    distance_upper_bound=cut)
            # Use a large number to avoid numpy.inf if no neighbors found
            return min(10000., min(distances))
        xyz = coords - self.gravity_center
        return min(transform.get_vector_magnitude(v) for v in xyz)

    def getGravityEnergy(self, target):
        """
        Evaluate the gravitational energy of the given target. The energy is
        simply the difference of the original distance between the target and
        the gravitational source and the new distance between them, scaled by
        the gravity weight.

        :type target: `MonteCarloMolecule`
        :param target: A molecule that has been randomly moved

        :rtype: float
        :return: The gravitational energy of the target's new position
        """
        # NOTE(review): without a scaffold tree this returns 0.0, so the
        # gravity-center branch of getClosestApproach is never reached from
        # here even though the gravity option is documented as attracting to
        # the cell center when no scaffold is present - confirm the intent
        if not self.ckd_tree:
            return 0.0
        old_distance = self.getClosestApproach(target.original_coords)
        new_distance = self.getClosestApproach(target.getCoordinates())
        return (new_distance - old_distance) * self.gravity_weight

    def getClashPenalty(self, candidate):
        """
        Get the energy penalty due to clashes

        :type candidate: `schrodinger.structure.Structure`
        :param candidate: The structure to check for clashes

        :rtype: float
        :return: The penalty based on the number of clashes
        """
        return self.clash_penalty * self.getNumClashes(candidate)

    def simulate(self):
        """
        Run the Monte Carlo simulated annealing

        :rtype: `schrodinger.structure.Structure`
        :return: None. The final cell is available as self.cell.
        """
        # Because simulate manages the same structure across each iteration,
        # open a minimizer context outside the loop to avoid repeat calls to
        # mmlewis_apply and mmffld_enterMol, instead only updating coordinates
        # for each call to getEnergy; ExitStack is used to optionally
        # open/close that context if self.forcefield is defined
        with contextlib.ExitStack() as stack:
            if self.forcefield:
                minimizer = stack.enter_context(
                    minimize.minimizer_context(ffld_version=self.forcefield,
                                               cleanup=self.cleanup))
            else:
                minimizer = None
            self._simulate(minimizer)

    def _simulate(self, minimizer):
        """
        Run the annealing loop over all temperatures

        :type minimizer: minimizer object or None
        :param minimizer: The open minimizer used for forcefield energies, or
            None if forcefield energies are turned off

        :rtype: `schrodinger.structure.Structure`
        :return: The cell after all accepted moves

        :raise RuntimeError: If no cell structure has been set
        """
        if not self.cell:
            raise RuntimeError('The cell structure must be set before '
                               'simulation')
        # The cell may have been assigned after construction, in which case the
        # clash checker has not seen it yet - (re)build its ring cache here
        if self.clasher is not None:
            self.clasher.checkForRings(self.cell)
        # Set up the minimizer if needed
        if minimizer:
            # Turn off force calculation - we want energies only
            mm.mmffld_setOption(minimizer.handle, mm.MMFfldOption_ENE_NO_FORCES,
                                1, 0.0, "")
            # Must use a copy of the original cell structure here, because we
            # are continually updating those coordinates to get the energy of
            # the new structure
            minimizer.setStructure(self.cell.copy())
            self.scaffold.setMinimizerPBCProperties(minimizer.handle)
            # Turning off pair list updates saves some time - FIXME - stopped
            # doing this because we don't have a good way of knowing when we
            # should regenerate them. MATSCI-2022
            # mm.mmffld_setOption(minimizer.handle,
            #   mm.MMFfldOption_SYS_PAIR_LIST_UPDATE, 0, 0.0, "")
        # Initialize some data
        # Make sure we don't move any scaffold molecules
        first_mol = self.findFirstDisorderedMolecule()
        last_mol = self.cell.mol_total
        _, ffen, _, clashen = self.getEnergy(minimizer, self.cell)
        previous_energy = ffen + clashen
        for temp in self.temperatures:
            start = time.time()
            self.kt = temp * K_KCAL_PER_MOL
            num_accepted = 0
            for cycle in range(1, self.iterations + 1):
                # Pick a molecule to move, move it and compute the energy
                candidate = self.cell.copy()
                candidate.title = str(cycle) + '_' + str(int(temp))
                target = self.getTargetMolecule(candidate, first_mol, last_mol)
                self.performMovement(target)
                total_energy, ffen, _, clashen = self.getEnergy(minimizer,
                                                                candidate,
                                                                target=target)
                # Determine if the move is accepted, and if so, update data
                if self.isAccepted(previous_energy, total_energy):
                    self.cell = candidate
                    # We do not include the gravitational energy because that
                    # energy is always computed as a delta from the previous
                    # geometry rather than as a raw total
                    previous_energy = ffen + clashen
                    num_accepted += 1
            # Log information about this temperature cycle
            total_time = time.time() - start
            self.log('Seconds for %.2f temp cycle: %.1f' % (temp, total_time))
            self.log('Accepted %d out of %d steps' %
                     (num_accepted, self.iterations))
        return self.cell

    def performMovement(self, target):
        """
        Randomly rotate or translate the target molecule. Rotation is chosen
        with probability rotate_chance (from the rotate/translate weights),
        otherwise the molecule is translated. The limit stored for the chosen
        move type is passed to the molecule.

        :type target: `MonteCarloMolecule`
        :param target: The molecule to move

        :rtype: int
        :return: The type of move performed, either ROTATE or TRANSLATE
        """
        if random.random() < self.rotate_chance:
            move = self.ROTATE
            target.randomlyRotate(self.limits[move])
        else:
            move = self.TRANSLATE
            target.randomlyTranslate(self.limits[move])
        return move

    def getTargetMolecule(self, candidate, first, last):
        """
        Select the molecule to move this iteration

        :type candidate: `schrodinger.structure.Structure`
        :param candidate: The entire cell containing all molecules

        :type first: int
        :param first: The first valid molecule number to pick

        :type last: int
        :param last: The last valid molecule number to pick (inclusive)

        :rtype: `MonteCarloMolecule`
        :return: The MCM object for the chosen molecule
        """
        molnum = random.randint(first, last)
        return MonteCarloMolecule(candidate, molnum, self.box)

    def findFirstDisorderedMolecule(self):
        """
        Find the first molecule number that isn't part of the scaffold

        :rtype: int
        :return: The first non-scaffold molecule number
        """
        if self.scaffold:
            # The first atom after the scaffold atoms belongs to the first
            # movable molecule
            scaf_atoms = self.scaffold.scaffold.atom_total
            first_non_scaff_atom = self.cell.atom[scaf_atoms + 1]
            return first_non_scaff_atom.molecule_number
        return 1

    def getEnergy(self, minimizer, candidate, target=None):
        """
        Compute the total energy of the system

        :type minimizer: minimizer object or None
        :param minimizer: The minimizer used to evaluate the forcefield energy,
            or None to skip the forcefield term

        :type candidate: `schrodinger.structure.Structure`
        :param candidate: The entire cell containing all molecules

        :type target: `MonteCarloMolecule`
        :param target: The MCM object for the just-moved molecule (None if this
            is the 0th iteration)

        :rtype: (float, float, float, float)
        :return: The total energy, forcefield energy, gravitational energy and
            clash energy (total energy is the sum of the last three)
        """
        ffen = graven = clashen = 0.0
        if minimizer:
            minimizer.updateCoordinates(candidate)
            ffen = minimizer.getEnergy()
        if self.gravity and target is not None:
            graven = self.getGravityEnergy(target)
        # The hard shell clash model is only used when there is no forcefield
        # term to keep gravity from collapsing molecules onto each other
        if self.gravity and not self.forcefield:
            clashen = self.getClashPenalty(candidate)
        total = ffen + graven + clashen
        return total, ffen, graven, clashen

    def isAccepted(self, old_energy, new_energy):
        """
        Use the Metropolis equation to determine if the move is accepted

        Uses self.kt, which is set for each temperature by the simulation.

        :type old_energy: float
        :param old_energy: The previous energy

        :type new_energy: float
        :param new_energy: The new energy

        :rtype: bool
        :return: Whether the move is accepted or not
        """
        if new_energy < old_energy:
            return True
        if self.kt == 0:
            # Zero temperature limit: uphill moves are never accepted, and
            # there is no finite Boltzmann factor to divide for
            return new_energy == old_energy
        probability = math.exp(-(new_energy - old_energy) / self.kt)
        die_roll = random.random()
        return die_roll < probability

    def log(self, msg, level=logging.INFO):
        """
        Add a message to the log file

        :type msg: str
        :param msg: The message to add

        :type level: int
        :param level: A `logging` priority level of the message
        """
        if not self.logger:
            return
        self.logger.log(level, msg)