"""
An interface to the Maestro color palette and color schemes.
Color schemes are read from `maestro-v<version>/data/res/scheme.res`.
Copyright Schrodinger, LLC. All rights reserved.
"""
#Contributors: Jeff A. Saunders, Matvey Adzhigirey, Herc Silverstein
import os
from schrodinger.infra import mm
from schrodinger.infra.mmbitset import Bitset
from schrodinger.infra.mminit import Initializer
from schrodinger.job.util import hunt
from schrodinger.utils import colorscheme as maestro_cscheme
from schrodinger.utils.fileutils import get_mmshare_data_dir
# Hand the mmcolor initialize/terminate functions to an Initializer at import
# time (presumably so mmcolor stays initialized while this module is loaded --
# confirm against schrodinger.infra.mminit).
_initializer = Initializer([mm.mmcolor_initialize], [mm.mmcolor_terminate])
# Lazily populated caches, filled in by _load_scheme_dict().  None means that
# scheme.res has not been read yet.
# _scheme_dict maps a scheme short name to its ColorScheme instance.
_scheme_dict = None
# _scheme_name_dict maps a scheme long name to the matching short name.
_scheme_name_dict = None
#############################################################################
# FUNCTIONS
#############################################################################
def get_rgb_from_color_index(index):
    """
    A convenience function for getting the RGB value for a color index.

    :param index: an integer specifying the color index
    :type index: int

    :return: the (red, green, blue) components that
        `mm.mmcolor_index_to_vector` reports for the index, as a tuple.
        Each value is expected to be an integer in the 0-255 range.
    :rtype: tuple
    """
    rgb_vector = mm.mmcolor_index_to_vector(index)
    return tuple(rgb_vector)
#############################################################################
## CLASSES
#############################################################################
class Color(object):
    """
    Represent a color as either an integer (colormap index), string
    (color name or hex "RRGGBB" value), or an RGB value
    (tuple/list of 3 ints, values 0-255).

    Provides the following properties and methods:

        - Color.index = int(Color) - mmcolor index of the closest color
        - Color.name = str(Color) - mmcolor name of the closest color
        - Color.rgb - (tuple of 0-255 ints)
        - equal = (col1 == col2)

    When object is initialized from the RGB value, the Color.index and
    Color.name attributes are set to the closest color in the mmcolor palette.

    Instances are hashable; the hash is derived from `index`, which keeps it
    consistent with equality against other `Color` objects and against plain
    color indices.
    """

    def __init__(self, color):
        """
        :param color: a colormap index, a color name or "RRGGBB" hex string,
            or a list/tuple of 3 ints.
        :type color: int, str, list or tuple

        :raise ValueError: if the index, name/hex string or RGB triple is
            rejected.
        :raise TypeError: if `color` is of any other type.
        """
        if isinstance(color, int):
            # Color index (int)
            self.index = color
            try:
                self.name = mm.mmcolor_index_to_name(color)
            except mm.MmException as err:
                raise ValueError("Invalid color index: %i" % color) from err
            self.rgb = tuple(mm.mmcolor_index_to_vector(self.index))
        elif isinstance(color, str):
            # Color name or hex 'RRGGBB' string value
            try:
                # Attempt reading as a name first
                self.index = mm.mmcolor_name_to_index(color)
                self.rgb = tuple(mm.mmcolor_index_to_vector(self.index))
                self.name = color
            except mm.MmException:
                # Not a valid name; parse the string as an RRGGBB value:
                try:
                    self.rgb = tuple(mm.mmcolor_name_to_vector(color))
                    self.index = mm.mmcolor_vector_to_index(self.rgb)
                    self.name = mm.mmcolor_index_to_name(self.index)
                except mm.MmException as err:
                    raise ValueError(
                        'Invalid color name/RRGGBB string: "%s"' %
                        color) from err
        elif isinstance(color, (list, tuple)):
            # RGB value
            if len(color) != 3:
                raise ValueError("Color must be a tuple of 3 ints")
            if any(not isinstance(value, int) for value in color):
                raise ValueError("Color must be a tuple of 3 ints")
            self.index = mm.mmcolor_vector_to_index(color)
            # The code treats an index of 0 as the failure result of the
            # lookup.
            if self.index == 0:
                raise ValueError("mmcolor_vector_to_index() failed.")
            self.name = mm.mmcolor_index_to_name(self.index)
            self.rgb = tuple(color)
        else:
            raise TypeError('Invalid color: %s' % color)

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name

    def __index__(self):
        return self.index

    def __eq__(self, other):
        """
        Compare against another `Color` (by index), an int (against the
        index) or a str (against the name).  Any other type is unequal.
        """
        if isinstance(other, Color):
            return other.index == self.index
        if isinstance(other, int):
            return other == self.index
        if isinstance(other, str):
            return other == self.name
        # Any other type passed
        return False

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Defining __eq__ alone makes a class unhashable in Python 3, so
        # provide a hash explicitly.  Equal Color objects share an index, and
        # a Color compares equal to its index, so hash the index.
        return hash(self.index)

    @property
    def rgb_float(self):
        """
        Returns a list of [R, G, B] for this color, each ranging from 0.0 to
        1.0.
        """
        return [x / 255.0 for x in self.rgb]

    @property
    def hex_string(self):
        """
        Returns the color as string of hex RGB values (RRGGBB). For example,
        pure red will be returned as "FF0000".
        """
        return mm.mmcolor_vector_to_string(self.rgb)
class ColorScheme(maestro_cscheme.MM_CScheme):
    """
    Define a Maestro color scheme.

    This class provides the following functionality::

        colorscheme.apply(st, [atoms])
        for rule in colorscheme:
            <do>

    Iterating over a scheme yields the rule objects returned by the base
    class `getRules()`.
    """

    def __init__(self, name='element', color_asl_list=None):
        """
        Create ColorScheme object

        :param name: Name of ColorScheme
        :type name: str
        :param color_asl_list: (color, asl) pairs to add as rules.  Each pair
            is passed to `add`; None or an empty sequence adds no rules.
        :type color_asl_list: iterable of (color, asl) or None
        """
        self._name = name
        # (color, asl) pairs and, in parallel, the description of each rule
        # added through add(); kept so that copy() can rebuild the scheme.
        self._color_asl_list = []
        self._rule_descriptions = []
        sanitized = ''.join(c if c.isalnum() else "_" for c in name)
        self.filename = sanitized.lower() + '.sch'
        # short and long name are the same, but required for maestro cscheme
        sl_name = self.filename.upper()
        super().__init__(short_name=sl_name,
                         long_name=sl_name,
                         original_name=name)
        for color, asl in color_asl_list or ():
            self.add(color, asl)

    def add(self, color_str, asl, rule_description=''):
        """
        Add another rule to this color scheme.

        :param color_str: color to apply; converted with `str()` before it
            is handed to the rule.
        :param asl: what to apply the color to
        :param rule_description: description stored with the rule
        :type rule_description: str
        """
        # Currently prints a warning saying that the color name is not a valid
        # color for the current color scheme
        scheme_rule = maestro_cscheme.MM_CSchemeRule(
            color_name=str(color_str),
            asl_spec=asl,
            scheme_name=str(self._name),
            description=str(rule_description))
        # Remove python ownership of the scheme_rule so that its memory is
        # managed only by the c destructor. Will cause seg fault if this is not
        # set.
        scheme_rule.thisown = 0
        super().addRule(scheme_rule)
        self._color_asl_list.append((color_str, asl))
        self._rule_descriptions.append(rule_description)

    def __iter__(self):
        """
        Iterate over all entries in this scheme.

        Yields each rule object returned by the base class `getRules()`.
        """
        yield from super().getRules()

    def __len__(self):
        """
        Return the number of rules in the scheme
        """
        return len(super().getRules())

    def copy(self):
        """
        Return a copy of this scheme.

        The copy is rebuilt from the rules that were added through `add`
        (including their descriptions).
        """
        dup = ColorScheme(self._name)
        recorded_rules = zip(self._color_asl_list, self._rule_descriptions)
        for (color, asl), description in recorded_rules:
            dup.add(color, asl, description)
        return dup

    def apply(self, st, atoms=None):
        """
        Applies the scheme to the specified Structure <st>.

        :param atoms: Optionally specify which atoms to apply the scheme to.
                      Can be a list of atom indices, a Bitset instance, or
                      an ASL expression string (not recommended, as the
                      evaluation can be slow).  All atoms are used when
                      None.
        """
        num_atoms = st.atom_total
        # Initialize mmasl before any mmasl call is made (including the ASL
        # parsing branch below) and always pair it with a terminate.
        mm.mmasl_initialize(mm.error_handler)
        try:
            # Make a Bitset from the list of specified atoms:
            if atoms is None:
                atoms_bs = Bitset(size=num_atoms)
                atoms_bs.fill()
            elif isinstance(atoms, Bitset):
                atoms_bs = atoms
            elif isinstance(atoms, str):  # ASL expression
                # NOT recommended, as the evaluation can be slow
                atoms_bs = Bitset(size=num_atoms)
                mm.mmasl_parse_input(atoms, atoms_bs, st, 1)
            else:  # Python list or iterator (assumed)
                atoms_bs = Bitset(size=num_atoms)
                for atom in atoms:
                    atoms_bs.set(atom)
            super().applyScheme(st, atoms_bs)
        finally:
            mm.mmasl_terminate()

    def writeSchemeFile(self, filename):
        """
        Write the scheme to the specified .sch file.

        :param filename: filelocation to save scheme file to
        :type filename: str
        """
        super().setFileName(filename)
        super().writeFile()
class ColorRamp(object):
    """
    An object for calculating colors on a customizable color ramp.

    Coloring atoms according to a calculated property that ranges from 0 to 10::

        color_ramp = ColorRamp(colors=("white", "blue"), values=(0,10))
        for atom in st.atom:
            property = calc_property(atom)
            r, g, b  = color_ramp.getRGB(property)
            atom.setColorRGB(r, g, b)

    Coloring atoms according to a calculated property that ranges from -10
    to 10 using blues for negative values and reds for positive values::

        color_ramp = ColorRamp(colors=("blue", "white", "red"),
                               values=(-10, 0, 10))
        for atom in st.atom:
            property = calc_property(atom)
            color = color_ramp.getRGB(property)
            atom.setColorRGB(*color)
    """

    def __init__(self, colors=("white", "blue"), values=(0, 100)):
        """
        Initialize a ColorRamp object where the specified values correspond to
        the given colors

        :param colors: The list of colors.  Any color description that is
            recognized by `Color` may be used (a color name or colormap index).
        :type colors: list or tuple
        :param values: The list of numerical values.  This list must be the same
            length as `colors`, all values must be unique, and the list must be
            sorted in either ascending or descending order.
        :type values: list or tuple

        :raise ValueError: if the lists differ in length, are empty, start
            and end on the same value, or are not strictly sorted.
        """
        if len(colors) != len(values):
            raise ValueError("Color and value lists must be of equal length")
        if len(values) == 0:
            # Without this guard the endpoint comparison below would fail
            # with an unhelpful IndexError.
            raise ValueError("ColorRamp requires at least two values")
        if values[0] < values[-1]:
            values = list(values)
            colors = list(colors)
        elif values[0] > values[-1]:
            # Store descending input in ascending order.
            values = list(reversed(values))
            colors = list(reversed(colors))
        else:
            raise ValueError("ColorRamp values are equal")
        self._values = [float(value) for value in values]
        self._colors = [Color(color).rgb_float for color in colors]
        for lower, upper in zip(self._values, self._values[1:]):
            if lower >= upper:
                raise ValueError("ColorRamp values must be sorted and unique")

    def getRGB(self, value):
        """
        Determine the color that corresponds to the specified value

        :param value: The value to calculate the color for
        :type value: int or float

        :return: The color corresponding to the specified value, where the color
            is a represented by a list of (red, green, blue) integers in the 0-255
            range.
        :rtype: list
        """
        new_color_f = self._getRGBFloat(value)
        return [int(round(component * 255)) for component in new_color_f]

    def _getRGBFloat(self, value):
        """
        Determine the color (in float format) that corresponds to the specified
        value

        Values at or below the first ramp value get the first color, values
        above the last ramp value get the last color, and anything in between
        is linearly interpolated between the two neighbouring colors.

        :param value: The value to calculate the color for
        :type value: int or float

        :return: The color corresponding to the specified value, where the color
            is a represented by a list of (red, green, blue) floats in the 0.0-1.0
            range.  The list is always a new object, so callers may modify
            it freely.
        :rtype: list
        """
        for i, cur_val in enumerate(self._values):
            if value <= cur_val:
                val_index = i - 1
                break
        else:
            # Larger than every ramp value: clamp to the last color.
            return list(self._colors[-1])
        if val_index == -1:
            # At or below the first ramp value: clamp to the first color.
            return list(self._colors[0])
        prev_val = self._values[val_index]
        next_val = self._values[val_index + 1]
        scale = (value - prev_val) / (next_val - prev_val)
        low_color = self._colors[val_index]
        high_color = self._colors[val_index + 1]
        return [
            color1 + scale * (color2 - color1)
            for color1, color2 in zip(low_color, high_color)
        ]
class RainbowColorRamp(ColorRamp):
    """
    A `ColorRamp` whose colors run through the rainbow, from red at
    `min_value` to purple at `max_value`, with the intermediate colors spaced
    evenly between the two.
    """
    # colors taken from mmshare/mmlibs/colorscheme/ramps/rainbow.rmp
    COLORS = ("red1", "user10", "user12", "user14", "user15", "user16",
              "user17", "user18", "user19", "user20", "user21", "green",
              "user53", "user54", "user55", "user56", "user57", "user58",
              "user59", "user60", "user26", "user28", "user30", "user32",
              "user61", "user62", "purple")

    def __init__(self, min_value=0, max_value=100):
        """
        :param min_value: The value corresponding to red.
        :type min_value: int
        :param max_value: The value corresponding to purple.
        :type max_value: int
        """
        # N colors delimit N - 1 intervals; dividing by N would place the
        # last color short of max_value.
        num_intervals = len(self.COLORS) - 1
        step = (max_value - min_value) / num_intervals
        values = [min_value + i * step for i in range(len(self.COLORS))]
        # Pin the end point so float rounding cannot move it off max_value.
        values[-1] = max_value
        super().__init__(self.COLORS, values)
#############################################################################
## GLOBAL FUNCTIONS
#############################################################################
def _load_scheme_dict():
    """
    Populate the module-level color scheme caches from scheme.res.

    Raises RuntimeError if Maestro installation is missing.
    Raises IOError if scheme.res file could not be found
    """
    global _scheme_dict
    global _scheme_name_dict
    res_path = _find_scheme_res_file()
    if not res_path:
        raise IOError("Couldn't find 'scheme.res' file.")
    # Parse first, publish afterwards: the caches are only replaced once
    # parsing has succeeded.
    parsed_schemes, parsed_names = _parse_scheme_files(res_path)
    _scheme_dict = parsed_schemes
    _scheme_name_dict = parsed_names
def available_color_schemes():
    """
    Return a list of available color schemes (list of names).

    The scheme definitions are loaded on first use and cached at module
    level.

    Raises RuntimeError if Maestro installation is not available.
    Raises IOError if scheme.res file could not be found
    """
    # Read-only access to the module cache; _load_scheme_dict() is what
    # rebinds it.
    if _scheme_dict is None:
        _load_scheme_dict()
    return list(_scheme_dict)
def get_color_scheme(name):
    """
    Return a ColorScheme object for scheme <name>.

    :param name: short or long name of the scheme
    :type name: str

    Raises ValueError if such scheme does not exist.
    Raises RuntimeError if Maestro installation is not available.

    Any other failure while loading the scheme definitions (for example
    scheme.res not being found) is printed, the caches are left empty, and
    the lookup then fails with ValueError.
    """
    global _scheme_dict
    global _scheme_name_dict
    if _scheme_dict is None:
        try:
            _load_scheme_dict()
        except RuntimeError:
            # Maestro installation is missing
            raise
        except Exception as err:
            # FIXME Why not simply raise the exception???
            print("ERROR:", err)
            _scheme_dict = {}
            _scheme_name_dict = {}
    # A "long" name may have been specified; map it to the "short" name.
    short_name = _scheme_name_dict.get(name, name)
    scheme_instance = _scheme_dict.get(short_name)
    # Compare against None explicitly: a scheme without rules has a length
    # of zero and would otherwise be mistaken for a missing scheme.
    if scheme_instance is None:
        raise ValueError('Invalid color scheme name: "%s"' % short_name)
    return scheme_instance
def apply_color_scheme(st, scheme, atom_list=None):
    """
    Applies the scheme to the specified Structure <st>.
    Optionally a list of atom indices may be specified.

    scheme
        One of the names returned by available_color_schemes() or a
        ColorScheme object returned by get_color_scheme().

    atom_list
        A list of atom indices to apply color scheme to (default all atoms).

    Raises ValueError if such scheme does not exist.
    Raises RuntimeError if Maestro installation is not available.
    """
    # A string is treated as a scheme name and resolved first; anything else
    # is used as the scheme object itself.
    if isinstance(scheme, str):
        scheme = get_color_scheme(scheme)
    scheme.apply(st, atom_list)
#############################################################################
## PRIVATE HELPERS (locating and parsing scheme.res)
#############################################################################
def _find_scheme_res_file():
    """
    Locate the scheme.res file to read color schemes from.

    :return: path of the first scheme.res found, or None if there is none.

    Will raise RuntimeError if the Schrodinger application data directory
    cannot be determined.
    """
    # Locations searched for scheme.res, in order:
    # .
    # <app_data>/maestroXY
    # <mmshare data dir>
    local_loc = 'scheme.res'
    if os.path.isfile(local_loc):
        return local_loc
    try:
        appdata_dir = mm.mmfile_schrodinger_appdata_dir()
    except Exception as err:
        raise RuntimeError("Could not determine the Schrodinger "
                           "application data directory.") from err
    maestro_exec = os.environ.get('MAESTRO_EXEC') or hunt('maestro')
    if maestro_exec:
        # Check for the custom scheme.res file only if Maestro is installed:
        # Determine 2-digit version from MAESTRO_EXEC (i.e '75', '80').
        # A path without the "maestro-v" marker carries no version to look
        # up, so the custom location is skipped in that case.
        _, marker, version_tail = maestro_exec.partition('maestro-v')
        if marker:
            mae_ver = version_tail[:2]
            maestro_ver_loc = os.path.join(appdata_dir,
                                           'maestro%s' % mae_ver,
                                           'scheme.res')
            if os.path.isfile(maestro_ver_loc):
                return maestro_ver_loc
    # Open the scheme.res file from the distribution:
    builtin_loc = os.path.join(get_mmshare_data_dir(), 'scheme.res')
    if os.path.isfile(builtin_loc):
        return builtin_loc
    return None
def _parse_scheme_files(scheme_res_file):
    """
    Build the color schemes listed in a scheme.res file.

    The rules of every scheme are read from the `schemes` directory that
    sits next to the scheme.res file.

    :param scheme_res_file: path of the scheme.res file
    :type scheme_res_file: str

    :return: (scheme_dict, scheme_name_dict); see the comments below.
    :rtype: tuple

    :raise ValueError: if a scheme entry does not have exactly three lines.
    """
    scheme_dict = {}  # key: scheme short name; value: ColorScheme instance
    scheme_name_dict = {}  # key: scheme long name; value: scheme short name
    schemes_dir = os.path.join(os.path.dirname(scheme_res_file), 'schemes')
    entries = _read_scheme_res_entries(scheme_res_file)
    for shortname, longname, filename, is_custom in entries:
        color_asl_list = []
        _read_scheme_rules(filename, schemes_dir, color_asl_list, is_custom)
        scheme_dict[shortname] = ColorScheme(shortname, color_asl_list)
        scheme_name_dict[longname] = shortname
    return scheme_dict, scheme_name_dict


def _read_scheme_res_entries(scheme_res_file):
    """
    Return the (shortname, longname, filename, is_custom) entries of a
    scheme.res file.
    """
    # File format: blank lines and lines starting with '#' are ignored; the
    # remaining lines come in groups of three:
    #   Line1: short name
    #   Line2: long name
    #   Line3: filename
    # Groups between a "[" line and a "]" line are flagged as custom.
    entries = []
    fields = []
    in_custom_list = False
    with open(scheme_res_file) as fh:
        for line in fh:
            ls = line.strip()
            if not ls or ls[0] == '#':
                continue
            if ls == "]":
                # End of custom list
                _store_scheme_res_entry(entries, fields, in_custom_list)
                fields = []
                in_custom_list = False
                continue
            if len(fields) == 3:
                _store_scheme_res_entry(entries, fields, in_custom_list)
                fields = []
            if ls == "[":
                # Start of custom list; anything collected for an unfinished
                # group at this point is discarded.
                fields = []
                in_custom_list = True
                continue
            fields.append(ls)
    _store_scheme_res_entry(entries, fields, in_custom_list)
    return entries


def _store_scheme_res_entry(entries, fields, is_custom):
    """
    Append a completed three-line group to `entries`; ignore an empty group.
    """
    if not fields:
        # Nothing collected, e.g. the file is empty or ends right after "]".
        return
    if len(fields) != 3:
        raise ValueError(
            'Incomplete scheme.res entry (expected short name, long name '
            'and filename): %r' % (fields,))
    shortname, longname, filename = fields
    entries.append((shortname, longname, filename, is_custom))
def _read_scheme_rules(filename, schemes_dir, color_asl_list, is_custom):
    """
    Read the scheme rules from filename and place them in the list
    variable color_asl_list

    :param filename: scheme file name, relative to `schemes_dir`.  For a
        custom scheme this may have the form "<file>:<color>:<asl>", in
        which case <color> and <asl> define the rule that replaces each
        ADDCUSTOM line of <file>.
    :type filename: str
    :param schemes_dir: directory holding the scheme files
    :type schemes_dir: str
    :param color_asl_list: list that the (color, asl) rules are appended to
    :type color_asl_list: list
    :param is_custom: whether `filename` describes a custom scheme
    :type is_custom: bool

    :raise ValueError: if a rule line yields neither a color nor an ASL.
    """
    custom_rule = None
    if is_custom:
        # Figure out the custom color name and ASL expression
        toks = filename.split(":")
        if len(toks) == 3:
            filename, custom_color_name, custom_asl_spec = toks
            custom_rule = (custom_color_name, custom_asl_spec)
    # Join with the schemes directory exactly once; joining twice produces a
    # wrong path whenever schemes_dir is relative.
    filename = os.path.join(schemes_dir, filename)
    if not os.path.isfile(filename):
        print('color.py warning: Define scheme is missing file:')
        print('  ', filename)
        return
    with open(filename) as fh:
        for line in fh:
            if line.startswith('#'):
                continue
            if not line.split():
                continue
            if 'LEGEND' in line:
                break
            if 'INCLUDE' in line:
                fname = line.split()[1]
                # Note here assume not CUSTOM schemes will be INCLUDED
                _read_scheme_rules(fname, schemes_dir, color_asl_list, False)
                continue
            if line.startswith('DESCRIPTION'):
                break  # stop reading the file
            if 'ADDCUSTOM' in line:
                # The line is a placeholder for the custom rule.  Without a
                # custom color/ASL there is nothing to put in its place, so
                # the line is skipped.
                if custom_rule is not None:
                    color_asl_list.append(custom_rule)
                continue
            # The first two non-empty tab-separated fields are the color and
            # the ASL; any further fields are ignored.
            fields = [item for item in line.split('\t') if item]
            color = fields[0] if fields else None
            asl = fields[1] if len(fields) > 1 else None
            # NOTE(review): this only fires when neither field was found, so
            # a line with a color but no ASL is stored with asl=None.  The
            # message suggests "or" may have been intended -- confirm
            # against real scheme files before tightening the check.
            if not color and not asl:
                raise ValueError(
                    'Invalid color ASL entry (must have 2 tabs): "%s"' % line)
            color = color.replace('"', '')
            color_asl_list.append((color, asl))