"""
R-Group Analysis Dead End elimination
"""
# RGA Dead End Elimination code
import math
import random
import sys
import time
from past.utils import old_div
import numpy
from schrodinger.utils import log
# Module-level logger, used for debug-level progress output in this module.
logger = log.get_output_logger(
    "schrodinger.application.canvas.r_group_dee")
class SimulatedAnnealing():
    """
    Simulated-annealing search for a low-energy solution.

    A solution holds one state index per position (input structure).
    Energies of solutions and of single-position moves are delegated to
    ``energy_matrix`` (see `DEE_EnergyMatrix`). Each move changes the state
    of exactly one position. The temperature starts at
    ``tmax_mult * numPos()`` and is multiplied by ``t_factor`` after each
    temperature step. The lowest-energy solution visited is kept in
    ``best_solution``.
    """

    @staticmethod
    def getChoiceSpectrum(choices):
        """
        Return a list in which position index ``i`` appears
        ``choices[i] - 1`` times.

        A uniform random pick from this list selects a position with
        probability proportional to the number of transitions it can make
        away from its current state. Positions with one choice are omitted.

        :param choices: number of states for each position
        :type choices: list(int)
        :rtype: list(int)
        """
        spectrum = []
        for i, choice in enumerate(choices):
            # E.g. a structure with 6 matches can make 5 distinct
            # transitions, so it is listed 5 times.
            spectrum.extend((choice - 1) * [i])
        return spectrum

    def __init__(self,
                 energy_matrix,
                 sa_seed=None,
                 t_factor=None,
                 tmax_mult=None):
        """
        Initializer for simulated annealing class.

        :param energy_matrix: precalculated pairwise matrix used by SA.
        :type energy_matrix: `DEE_EnergyMatrix`
        :param sa_seed: SA random number generator seed (default 0)
        :type sa_seed: int
        :param t_factor: Factor (<1) by which T will be multiplied at each
            T change (default 0.9)
        :type t_factor: float
        :param tmax_mult: Factor by which the number of structures is
            multiplied to get the starting T (default 6.0)
        :type tmax_mult: float
        """
        # Because this class can be initialized from many places, several of
        # which can optionally set the keyword arguments, all callers use
        # None as the default and the true defaults are defined here only.
        self.sa_seed = 0 if sa_seed is None else sa_seed
        self.t_factor = 0.9 if t_factor is None else t_factor
        self.tmax_mult = 6.0 if tmax_mult is None else tmax_mult
        self.random = random.Random()
        self.random.seed(self.sa_seed)
        self.energy_matrix = energy_matrix
        self.nmax = self.energy_matrix.numPos()  # no. input structures
        self.choices = self.energy_matrix.numChoices()  # per st, no. states
        # log10 of the number of configurations (the product of the
        # choices); the sum must start from 0.
        self.log10_configs = math.fsum(math.log10(c) for c in self.choices)
        self.choice_spectrum = SimulatedAnnealing.getChoiceSpectrum(
            self.choices)
        self.best_solution = []  # First solution encountered w. current best E
        # All solutions found with current best E:
        self.best_solution_set = set()
        # Best energy found so far, initialized to the highest FP value:
        self.best_energy = sys.float_info.max
        self.cpu_time = None
        self.t_max = self.tmax_mult * self.nmax  # starting T
        self.t_min = 0.001  # Quit if T gets this low
        # Number of geometric cooling steps needed to go from t_max to t_min:
        self.max_t_steps = int(1 + (math.log(self.t_min) -
                                    math.log(self.t_max)) /
                               math.log(self.t_factor))
        # Epsilon for floating-point comparison:
        self.fp_eps = 1.0e-10
        self.imax = sum(self.choices)  # sum of number of states for each st

    def _matches_best_energy(self, energy):
        """
        Return True if ``energy`` equals ``best_energy`` within ``fp_eps``.

        The tolerance is relative to ``energy``, or absolute when ``energy``
        is exactly zero. Incrementally accumulated energies carry rounding
        error, so exact comparison is not reliable.

        :param energy: energy to compare
        :type energy: float
        :rtype: bool
        """
        diff = abs(energy - self.best_energy)
        scale = abs(energy)
        if scale == 0.0:
            return diff < self.fp_eps
        return diff < self.fp_eps * scale

    def run(self):
        """
        Anneal from a random initial solution and record the best one found.

        Results are stored in ``best_energy``, ``best_solution``,
        ``best_solution_set`` (all solutions seen with the best energy) and
        ``cpu_time``. Progress is written to the module logger at debug
        level.

        :return: an empty status string
        :rtype: str
        """
        time_0 = time.perf_counter()
        # process_time() measures CPU time; perf_counter() is wall-clock.
        clock_0 = time.process_time()
        # Random initial solution:
        solution_new = [self.random.randint(0, c - 1) for c in self.choices]
        e_new = self.energy_matrix.calculateEnergy(solution_new)
        self.new_best(e_new, solution_new)
        e_old = e_new
        solution_old = solution_new[:]
        n_iter = 0
        t = self.t_max
        # Normally decrease T after this number of steps:
        max_steps_at_t = 20 * int(math.sqrt(self.imax))
        # Decrease T early if all of this many steps were accepted:
        max_steps_all_accept = max_steps_at_t // 2
        tsteps = 0  # Number of T steps so far
        logger.debug("Starting simulated annealing")
        logger.debug("SA random seed: %d\n" % self.sa_seed)
        logger.debug(
            "No. structures=%d; log10 configs=%.2f; sum of choices= %d" %
            (self.nmax, self.log10_configs, self.imax))
        logger.debug("Max. T levels=%d; T factor=%.2f; Max. steps at T=%d" %
                     (self.max_t_steps, self.t_factor, max_steps_at_t))
        # Number of consecutive acceptances that exhibit the current best E:
        acceptCnt_best_e = 0
        e_diff_max = 0  # Abs value of max E change in any iteration at any T
        e_t_start = sys.float_info.max  # E at the start of each T step
        delta_e = sys.float_info.max  # delta_e for each T step
        # A move needs at least one structure with two or more matches;
        # otherwise the initial solution is the only one possible and
        # getNewSolution() could not pick a structure to change.
        can_move = bool(self.choice_spectrum)
        if not can_move:
            logger.debug("Only one possible solution; nothing to anneal")
        while can_move and tsteps < self.max_t_steps:
            last_delta_e = delta_e
            last_best_energy = self.best_energy
            tsteps += 1
            acceptCnt = 0
            rejectCnt = 0
            totalCnt = 0
            for totalCnt in range(1, max_steps_at_t + 1):
                solution_new, istruct = self.getNewSolution(solution_old)
                e_diff = self.energy_matrix.calculateEnergyDifference(
                    solution_old, e_old, istruct, solution_new)
                e_diff_max = max(e_diff_max, abs(e_diff))
                e_new = e_old + e_diff
                is_e_best = self._matches_best_energy(e_new)
                eligible_for_boltzmann = True
                if is_e_best:
                    # Ground-state degeneracy: a revisit of an already known
                    # best-E solution is always rejected without trying
                    # Boltzmann.
                    eligible_for_boltzmann = self.is_another_best_config(
                        solution_new)
                if eligible_for_boltzmann and self.boltzmann_probability(
                        e_old, e_new, t) > self.random.random():
                    # Accept the new state:
                    if e_new < self.best_energy:
                        self.new_best(e_new, solution_new)
                    solution_old = solution_new[:]
                    e_old = e_new
                    acceptCnt += 1
                    if is_e_best:
                        acceptCnt_best_e += 1
                    else:
                        acceptCnt_best_e = 0
                else:
                    rejectCnt += 1
                n_iter += 1
                if totalCnt == max_steps_all_accept and acceptCnt == totalCnt:
                    # Everything accepted so far (T very high): next T early.
                    break
                if acceptCnt_best_e == max_steps_at_t:
                    # Too many consecutive acceptances of best-E structures.
                    break
            # e_old is now E at the end of the completed T step:
            self.report_t(t, e_t_start, e_old, tsteps, totalCnt, acceptCnt)
            delta_e = abs(e_old - e_t_start)
            delta_best_energy = abs(self.best_energy - last_best_energy)
            msg = None
            if (delta_e < self.fp_eps and last_delta_e < self.fp_eps and
                    delta_best_energy < self.fp_eps):
                msg = 'Converged because energy has stopped going down'
            elif totalCnt == rejectCnt:
                msg = 'Converged because there were no acceptances at last T'
            elif acceptCnt_best_e == max_steps_at_t:
                msg = ('Converged due to massive low-E degeneracy:\n'
                       '  %d consecutive acceptances had best E' %
                       (acceptCnt_best_e,))
            elif tsteps >= self.max_t_steps:
                msg = 'Exiting because exceeded iteration limit'
            if msg is not None:
                logger.debug(msg)
                break
            # Still running; proceed to the next temperature:
            t = self.new_temperature(t)
            e_t_start = e_old
        elapsed_time = time.perf_counter() - time_0
        self.cpu_time = time.process_time() - clock_0
        logger.debug("\nTemperature steps: %d" % tsteps)
        logger.debug("Total iterations: %d" % n_iter)
        logger.debug("|DeltaEMax| for an iteration: %.2f" % e_diff_max)
        logger.debug("exp(-|DeltaEMax|/initialT)= %.2f" %
                     (math.exp(-e_diff_max / self.t_max)))
        logger.debug("Best energy: %.4f" % self.best_energy)
        logger.debug("Number of degenerate best solutions found: %d" %
                     len(self.best_solution_set))
        logger.debug("Elapsed time: %.2f" % elapsed_time)
        logger.debug("CPU time: %.2f" % self.cpu_time)
        return ''

    def is_another_best_config(self, solution):
        """
        Record ``solution`` as a best-energy solution if it is new.

        :return: True if ``solution`` was not yet in ``best_solution_set``
            (it is then added), False if it had already been seen.
        :rtype: bool
        """
        key = tuple(solution)
        if key in self.best_solution_set:
            return False
        self.best_solution_set.add(key)
        return True

    def new_best(self, energy, solution):
        """
        Make ``solution`` the new best solution with energy ``energy``.

        This resets ``best_solution_set`` to contain only ``solution``.
        """
        self.best_energy = energy
        self.best_solution = tuple(solution)
        self.best_solution_set = {self.best_solution}

    def report_t(self, t, start_E, end_E, tsteps, totalCnt, acceptCnt):
        """
        Log energies and counters for one completed temperature step.

        ``start_E`` is not shown for the first step (``tsteps == 1``).
        """
        delta_E = end_E - start_E
        if tsteps == 1:
            logger.debug(
                "T, start_E, end_E, delta_E, best_E=" +
                " %10.4f %10s %10.4f %10s %10.4f" %
                (t, '----------', end_E, '----------', self.best_energy))
        else:
            logger.debug("T, start_E, end_E, delta_E, best_E=" +
                         " %10.4f %10.4f %10.4f %10.4f %10.4f" %
                         (t, start_E, end_E, delta_E, self.best_energy))
        logger.debug("  T-steps, T-iterations, T-acceptances, %d %d %d" %
                     (tsteps, totalCnt, acceptCnt))

    def getNewSolution(self, solution_old):
        """
        Propose a move that changes the state of one structure.

        The structure is drawn from ``choice_spectrum``. Its new state is
        drawn uniformly from its other states, so it always differs from the
        current one.

        :param solution_old: current solution (not modified)
        :type solution_old: list(int)
        :return: (new solution, index of the changed structure)
        :rtype: tuple(list(int), int)
        """
        solution_new = solution_old[:]
        istruct = self.random.choice(self.choice_spectrum)
        nchoices = self.choices[istruct]
        choice_increment = self.random.randint(1, nchoices - 1)
        old_choice = solution_old[istruct]
        solution_new[istruct] = (old_choice + choice_increment) % nchoices
        return (solution_new, istruct)

    def neighbor(self, solution):
        """
        Legacy move generator (superseded by `getNewSolution`).

        Returns a copy of ``solution`` with one random structure set to a
        random state, retrying until the copy differs. NOTE(review): this
        loops forever if every structure has only one choice.
        """
        s = solution[:]
        while True:
            mol = self.random.randint(0, self.nmax - 1)
            s[mol] = self.random.randint(0, self.choices[mol] - 1)
            if s != solution:  # make sure we have a new system state
                return s

    def new_temperature(self, old_t):
        """Return the next (lower) temperature: ``t_factor * old_t``."""
        return self.t_factor * old_t

    def boltzmann_probability(self, e_old, e_new, t):
        """
        Metropolis acceptance probability for moving from ``e_old`` to
        ``e_new`` at temperature ``t``.

        Downhill moves are always accepted. Uphill moves are never accepted
        at ``t == 0``; otherwise they are accepted with probability
        ``exp((e_old - e_new) / t)``.
        """
        if e_new < e_old:
            return 1.0
        elif t == 0.0:
            return 0.0
        return math.exp((e_old - e_new) / t)

    def getBestMatch(self):
        """Return the best solution found (a tuple after `run`)."""
        return self.best_solution
class DEE_Backtracking():
    """
    Depth-first (backtracking) minimizer over a `DEE_EnergyMatrix`.

    Partial solutions are pruned with ``energy_matrix.applyEnergyFilter``
    against the best complete energy found so far. The search gives up after
    ``max_nodes`` visited nodes, in which case the result may be suboptimal.
    """

    def __init__(self, energy_matrix, max_nodes=25000):
        """
        :param energy_matrix: energy matrix to minimize
        :type energy_matrix: `DEE_EnergyMatrix`
        :param max_nodes: maximum number of search-tree nodes to visit
        :type max_nodes: int
        """
        self.energy_matrix = energy_matrix
        self.nmax = self.energy_matrix.numPos()
        self.choices = self.energy_matrix.numChoices()
        # Starting solution (may be empty). When present, its energy is the
        # initial pruning threshold.
        self.result = self.energy_matrix.initialSolution()
        self.current_energy = sys.float_info.max
        if self.result:
            self.current_energy = self.energy_matrix.calculateEnergy(
                self.result)
        self.count = 0  # complete solutions evaluated
        self.nodeCnt = 0  # search-tree nodes visited
        self.maxCnt = max_nodes

    def minimize(self):
        """
        Run the backtracking search and store the best solution in
        ``result``.

        :return: empty string on success, or a warning if the node limit was
            hit and the solution may be suboptimal
        :rtype: str
        """
        logger.debug("starting backtracking algorithm")
        logger.debug("nmax: %s   choices: %s" % (self.nmax, self.choices))
        logger.debug("initial solution: %s" % (self.result,))
        logger.debug("initial energy:   %s" % (self.current_energy,))
        rStr = ""
        if self.backtrack([]):
            rStr = "Backtracking algorithm max node count exceeded. Solution may be suboptimal."
        logger.debug("nodes visited: %d" % self.nodeCnt)
        logger.debug("states: %s" % self.output(self.result).rstrip())
        logger.debug("final solution: %s" % (self.result,))
        logger.debug("final energy:   %s" % (self.current_energy,))
        return rStr

    def backtrack(self, solution):
        """
        Explore all extensions of the partial ``solution``.

        :return: True if the node limit was exceeded (search aborted),
            False otherwise
        :rtype: bool
        """
        self.nodeCnt += 1
        if self.nodeCnt > self.maxCnt:
            return True
        if self.reject(solution):
            return False
        if self.accept(solution):
            self.result = solution[:]
        s = self.first(solution)
        while s is not None:
            if self.backtrack(s):
                return True
            s = self.next(s)
        return False

    def reject(self, solution):
        """Return True if ``solution`` can be pruned."""
        return self.energy_matrix.applyEnergyFilter(solution,
                                                    self.current_energy)

    def accept(self, solution):
        """
        Evaluate a complete ``solution``.

        :return: True if it is complete and lowers ``current_energy`` (which
            is then updated), False otherwise
        :rtype: bool
        """
        if len(solution) != self.nmax:
            return False
        self.count += 1
        e = self.energy_matrix.calculateEnergy(solution)
        if e < self.current_energy:
            self.current_energy = e
            return True
        return False

    def first(self, solution):
        """
        Return ``solution`` extended with state 0 of the next position, or
        None if it is already complete.
        """
        if len(solution) == self.nmax:
            return None
        return solution + [0]

    def next(self, solution):
        """
        Return a copy of ``solution`` with its last state advanced by one,
        or None if that position has no further states.
        """
        last = len(solution) - 1
        if solution[last] < self.choices[last] - 1:
            s = solution[:]
            s[last] += 1
            return s
        return None

    def output(self, solution):
        """
        Format ``solution`` as 1-based indices into the flattened state list.

        State ``r`` of position ``c`` maps to ``sum(choices[:c]) + r + 1``.
        Each index is followed by a space and the line ends with a newline.

        :rtype: str
        """
        parts = []
        offset = 0
        for pos, state in enumerate(solution):
            parts.append(str(offset + state + 1) + ' ')
            offset += self.choices[pos]
        return ''.join(parts) + "\n"

    def getBestMatch(self):
        """Return the best solution found."""
        return self.result
class DEE_EnergyMatrix():
    """
    Pairwise energy matrix over all states of all positions, with
    dead-end-elimination (DEE) bookkeeping.

    States are numbered consecutively: state ``r`` of position ``c`` is
    row/column ``offset[c] + r`` of ``uij``. The energy of a solution is the
    sum of the diagonal terms ``uij[r][r]`` of its states plus ``uij[r][c]``
    for every pair of its states with ``r < c`` (upper triangle).
    ``singles`` lists eliminated states; ``pairs[x][y] == 1`` marks an
    eliminated pair of states.
    """

    def __init__(self, choices, uij):
        """
        :param choices: number of states for each position
        :type choices: list(int)
        :param uij: square matrix of size ``sum(choices)``, indexable as
            ``uij[i][j]``
        """
        self.nmax = len(choices)
        self.choices = choices
        self.uij = uij
        self.offset = []
        self.singles = []
        ns = sum(self.choices)  # total number of matrix rows/columns
        self.umin = numpy.zeros((self.nmax, self.nmax))
        self.umin_state = numpy.zeros((ns, self.nmax))
        # Precalculate min values of uij for each pair of positions
        # (self.umin) and each state/position pair (self.umin_state); these
        # provide the lower bounds used by applyEnergyFilter.
        i0 = 0
        for i in range(self.nmax):
            j0 = 0
            for j in range(self.nmax):
                self.umin[i, j] = sys.float_info.max
                for ia in range(self.choices[i]):
                    idx = i0 + ia
                    self.umin_state[idx, j] = sys.float_info.max
                    for ja in range(self.choices[j]):
                        u = self.uij[idx][j0 + ja]
                        if u < self.umin[i, j]:
                            self.umin[i, j] = u
                        if u < self.umin_state[idx, j]:
                            self.umin_state[idx, j] = u
                j0 += self.choices[j]
            self.offset.append(i0)
            i0 += self.choices[i]
        self.pairs = numpy.zeros((i0, i0))

    def numPos(self):
        """Return the number of positions."""
        return self.nmax

    def numChoices(self):
        """Return the list of state counts per position."""
        return self.choices

    def calculateEnergy(self, solution):
        """
        Return the energy of a (possibly partial) solution.

        Diagonal terms of each state plus upper-triangle pair terms.
        """
        e = 0.0
        s = self.convertSolution(solution)
        for r in s:
            e += self.uij[r][r]
            for c in s:
                if c > r:
                    e += self.uij[r][c]
        return e

    def calculateEnergyDifference(self, solution_old, e_old, istruct,
                                  solution_new):
        """
        Return ``E(solution_new) - E(solution_old)`` for solutions that
        differ only at position ``istruct``.

        Pair terms are read as ``uij[min][max]``, the same upper-triangle
        convention as `calculateEnergy`, so incremental and full energies
        agree even when ``uij`` is not symmetric. ``e_old`` is unused and
        kept for interface compatibility.
        """
        s_old = self.convertSolution(solution_old)
        s_new = self.convertSolution(solution_new)
        c_old = s_old[istruct]
        c_new = s_new[istruct]
        e_diff = 0.0
        for r in s_old:
            e_diff -= self.uij[min(r, c_old)][max(r, c_old)]
        for r in s_new:
            e_diff += self.uij[min(r, c_new)][max(r, c_new)]
        return e_diff

    def initialSolution(self):
        """
        Build a greedy starting solution.

        Position 0 takes state 0. Every other position takes the
        non-eliminated state with the lowest ``uij[0][idx]`` that is not
        pair-eliminated with state 0 of position 0.

        :return: the solution, or ``[]`` if some position has no usable
            state or there are no positions
        :rtype: list(int)
        """
        if self.nmax == 0:
            return []
        solution = [0]
        for i in range(1, self.nmax):
            umin = sys.float_info.max
            imin = -1
            for ia in range(self.choices[i]):
                idx = self.offset[i] + ia
                if idx in self.singles or self.pairs[0][idx] == 1:
                    continue
                if self.uij[0][idx] < umin:
                    umin = self.uij[0][idx]
                    imin = ia
            if imin == -1:
                return []
            solution.append(imin)
        return solution

    def checkSolution(self, solution):
        """
        Return False if ``solution`` contains an eliminated single or an
        eliminated (dead-end) pair, True otherwise.
        """
        s = self.convertSolution(solution)
        for r in s:
            if r in self.singles:
                return False
            for c in s:
                if self.pairs[r][c] == 1:
                    return False
        return True

    def convertSolution(self, solution):
        """Map per-position state indices to global matrix indices."""
        return [self.offset[c] + r for c, r in enumerate(solution)]

    def applyEnergyFilter(self, solution, current_energy):
        """
        Decide whether a partial solution can be pruned.

        Complete solutions are never rejected here. A partial solution is
        rejected if it contains eliminated singles/pairs, or if a lower
        bound on the energy of any completion is >= ``current_energy``.

        :rtype: bool
        """
        n_fixed = len(solution)
        if n_fixed == self.nmax:
            return False
        if not self.checkSolution(solution):
            return True
        s = self.convertSolution(solution)
        e = self.calculateEnergy(solution)
        # Add the best possible contribution from the remaining positions:
        for i in range(self.nmax):
            if i < n_fixed:
                for j in range(n_fixed, self.nmax):
                    e += self.umin_state[s[i]][j]
            else:
                e += self.umin[i][i]
                for j in range(i + 1, self.nmax):
                    e += self.umin[i][j]
        return bool(e >= current_energy)

    def eliminateSingles(self):
        """
        Eliminate single states using the Goldstein rule, appending them to
        ``singles``.
        """
        for k in range(self.nmax):
            for a in range(self.choices[k] - 1):
                ka = self.offset[k] + a
                if ka in self.singles:
                    continue
                for b in range(a + 1, self.choices[k]):
                    kb = self.offset[k] + b
                    if kb in self.singles:
                        continue
                    rc = self.applyGoldsteinSingles(k, a, b)
                    if rc in (1, 3):
                        self.singles.append(ka)
                    if rc in (2, 3):
                        self.singles.append(kb)
                    if ka in self.singles:
                        # Further comparisons against an eliminated state
                        # cannot eliminate anything.
                        break

    def eliminatePairs(self):
        """
        Mark dead-end pairs of states (one from each of two positions) in
        ``pairs``.

        :return: number of pairs eliminated by this call
        :rtype: int
        """
        cnt = 0
        for k in range(self.nmax - 1):
            for l in range(k + 1, self.nmax):
                for a in range(self.choices[k]):
                    ka = self.offset[k] + a
                    for b in range(self.choices[l]):
                        lb = self.offset[l] + b
                        if self._isDeadEndPair(k, l, a, b):
                            cnt += 1
                            self.pairs[ka][lb] = 1
                            self.pairs[lb][ka] = 1
        return cnt

    def _isDeadEndPair(self, k, l, a, b):
        """
        Return True if pair (state a of position k, state b of position l)
        is dominated by some non-eliminated pair (c, d) with c != a and
        d != b.

        Dominated means the best case for (a, b) over all other positions
        exceeds the worst case for (c, d).
        """
        ka = self.offset[k] + a
        lb = self.offset[l] + b
        eAB = self.uij[ka][ka] + self.uij[lb][lb] + self.uij[ka][lb]
        for c in range(self.choices[k]):
            if c == a:
                continue
            kc = self.offset[k] + c
            for d in range(self.choices[l]):
                if d == b:
                    continue
                ld = self.offset[l] + d
                if self.pairs[kc][ld] == 1:
                    continue
                eCD = self.uij[kc][kc] + self.uij[ld][ld] + self.uij[kc][ld]
                totMinAB = 0.0
                totMaxCD = 0.0
                for i in range(self.nmax):
                    if i == k or i == l:
                        continue
                    minAB = sys.float_info.max
                    maxCD = -sys.float_info.max
                    for t in range(self.choices[i]):
                        ix = self.offset[i] + t
                        eab = self.uij[ka][ix] + self.uij[lb][ix]
                        ecd = self.uij[kc][ix] + self.uij[ld][ix]
                        if eab < minAB:
                            minAB = eab
                        if ecd > maxCD:
                            maxCD = ecd
                    totMinAB += minAB
                    totMaxCD += maxCD
                if (eAB + totMinAB) > (eCD + totMaxCD):
                    return True
        return False

    def applyGoldsteinSingles(self, k, a, b):
        """
        Compare states ``a`` and ``b`` of position ``k`` with the Goldstein
        single-state criterion, honoring eliminated singles and pairs.

        :return: 0 if neither state can be eliminated, 1 to eliminate ``a``,
            2 to eliminate ``b``, 3 to eliminate both (some other position
            has no state compatible with either)
        :rtype: int
        """
        ka = self.offset[k] + a
        kb = self.offset[k] + b
        bNotA = False  # True if A can't be eliminated
        bNotB = False  # True if B can't be eliminated
        if ka in self.singles or kb in self.singles:
            return 0
        eDiffAB = self.uij[ka][ka] - self.uij[kb][kb]
        eDiffBA = -eDiffAB
        for i in range(self.nmax):
            if i == k:
                continue
            initA = False
            initB = False
            minDiffAB = sys.float_info.max
            minDiffBA = sys.float_info.max
            for c in range(self.choices[i]):
                ix = self.offset[i] + c
                if ix in self.singles:
                    continue
                if self.pairs[ka][ix] == 1:
                    if self.pairs[kb][ix] == 1:
                        continue
                    else:
                        # ix is compatible with B but not A, so B can't be
                        # swapped for A here.
                        bNotB = True
                elif self.pairs[kb][ix] == 1:
                    bNotA = True
                # check if we can not eliminate either A or B
                if bNotA and bNotB:
                    return 0
                eDiff = self.uij[ka][ix] - self.uij[kb][ix]
                if self.pairs[ka][ix] == 0:
                    if not initA:
                        minDiffAB = eDiff
                        initA = True
                    if minDiffAB > eDiff:
                        minDiffAB = eDiff
                if self.pairs[kb][ix] == 0:
                    if not initB:
                        minDiffBA = -eDiff
                        initB = True
                    if minDiffBA > -eDiff:
                        minDiffBA = -eDiff
            # No surviving partner at position i means the state is dead.
            if not (initA or initB):
                return 3
            elif not initA:
                return 1
            elif not initB:
                return 2
            eDiffAB += minDiffAB
            eDiffBA += minDiffBA
        if (not bNotA) and (eDiffAB > 0):
            return 1
        if (not bNotB) and (eDiffBA > 0):
            return 2
        return 0
def main(choices=None, uij=None):
    """
    Minimize an energy matrix with `DEE_Backtracking` and return the result.

    `DEE_EnergyMatrix` requires both the per-position state counts and the
    pairwise matrix. If either is omitted here, a small built-in
    demonstration problem is used instead.

    :param choices: number of states for each position
    :type choices: list(int) or None
    :param uij: square pairwise energy matrix of size ``sum(choices)``
    :return: the best solution found (one state index per position)
    :rtype: list(int)
    """
    if choices is None or uij is None:
        # Demonstration problem: two positions with two states each.
        choices = [2, 2]
        uij = numpy.array([[1.0, 0.0, 2.0, 3.0],
                           [0.0, 2.0, 1.0, 0.0],
                           [2.0, 1.0, 0.0, 0.0],
                           [3.0, 0.0, 0.0, 5.0]])
    matrix = DEE_EnergyMatrix(choices, uij)
    backtrack = DEE_Backtracking(matrix)
    backtrack.minimize()
    return backtrack.getBestMatch()
# Script entry point: run the demonstration when executed directly.
if __name__ == "__main__":
    main()