"""
Functions to parse MSJ files
Copyright Schrodinger, LLC. All rights reserved.
"""
import contextlib
from typing import Tuple
from schrodinger.application.desmond import ana
from schrodinger.utils import sea
def _canonicalize_key(key: str) -> str:
    key = key.strip()
    if key[0] == "[":
        return "stage" + key
    if key[:5] != "stage":
        return "stage[i]." + key
    return key
# FIXME: Refactor `ana.ArkDb` into a more fundamental module.
class Msj(ana.ArkDb):
    """
    `ana.ArkDb` specialization for the stage list of a parsed MSJ file.

    The use of this class is very similar to `ana.ArkDb`. Some terminologies and
    syntax defined in `ana.ArkDb` are not repeated here. So if you are unfamiliar
    with `ana.ArkDb`, you are encouraged to read the docstrings there. Below are
    examples to demonstrate and explain the use of this class::

        msj = parse(msj_fname)

        # 1. To read the `time` setting in stage 3:
        msj.get("stage[2].simulate.time")
        msj.get("stage[-1].simulate.time")
        # - `stage[2]` in the key corresponds to stage 3 because the stage
        #   index in the code is always zero based.
        # - Negative stage indices are supported and are of the same meaning
        #   as in Python.
        # - The stage type (here it is `simulate`) is always needed in the
        #   key. This is good for preventing some kind of index errors, and is
        #   very useful when you do NOT care about the index (see below).

        # 2. To read the `time` setting in the first `lambda_hopping` stage:
        msj.get("stage[i].lambda_hopping.time")

        # 3. Default value for non-existing setting:
        msj.get("stage[i].lambda_hopping.phony_setting", 1000)

        # 4. To change the value of a setting:
        msj.put("stage[i].lambda_hopping.time", 1000)

        # 5. To append a new stage `fep_analysis` and use the stage's default
        #    settings:
        msj.put("stage[$].fep_analysis", {})

        # 6. To insert a new stage `build_geometry` to be the 2nd stage:
        msj.put("stage[@]1.build_geometry", {})

        # 7. To insert a new stage `assign_forcefield` with custom settings:
        msj.put("stage[@]2.assign_forcefield",
                sea.Map('''
                    hydrogen_mass_repartition = off
                    make_alchemical_water = on'''))
        # Note that it's necessary to convert the settings in the form of a
        # string into a `sea.Map`.

        # 8. To delete the first "simulate" stage:
        msj.delete("stage[i].simulate")

        # 9. To delete the "simulate" stage whose title reads "production":
        msj.delete("stage[i].simulate", matches="title=production")

        # 10. To delete all "simulate" stages:
        msj.delete("stage[*].simulate")

        # 11. To access a particular stage, use the syntax:
        #     `msj.stage[<index>].<stage-name>.<setting>`:
        msj.stage[10].simulate.time.val += 1000

        # 12. Support for shorthand keys:
        msj.put("lambda_hopping.phony_setting", 2000)
        msj.get("lambda_hopping.phony_setting")
        msj.put("[@]1.build_geometry", {})
        # - If the key starts with the stage type name, it will be assumed to
        #   be prefixed with 'stage[i].'. In other words, the first instance
        #   of the stage type will be operated on.
        # - The key can start with '[i]', '[*]', '[@]', and '[$]', and it will
        #   be automatically prefixed with 'stage'.

        # 13. Find the index of the first "simulate" stage:
        first_simulate_index = msj.find_stages("simulate")[0].STAGE_INDEX
        # - `find_stages` always returns a tuple of the found stages, that's
        #   why `[0]` is used to get the first and the only found "simulate"
        #   stage.
        # - `STAGE_INDEX` gives the stage's index in the MSJ file. Note that
        #   the index is one-based.

        # 14. Find the index of the last "trim" stage:
        last_trim_index = msj.find_stages("[*].trim")[-1].STAGE_INDEX
        # - Here we use the key "[*].trim" to find all "trim" stages and then
        #   select the last one with `[-1]`.
        # - We don't use the key "trim" (remember it's a shorthand of
        #   "[i].trim"), because it means to get the first "trim" from the
        #   beginning of the stage list.

        # 15. Find the index of the second "simulate" stage:
        second_simulate_index = msj.find_stages("[*].simulate")[1].STAGE_INDEX

        # 16. Find the index of the "simulate" stage whose "title" parameter
        #     is set to "production":
        production_index = msj.find_stages(
            "[*].simulate.title",
            picker=lambda title: (title.parent() if title.val == "production"
                                  else None),
        )[0].STAGE_INDEX
        # - With the key "[*].simulate.title", we find the titles of all
        #   simulate stages, then we use a lambda function as the `picker` to
        #   select the stage whose title is "production". Note that the
        #   "parent" of the "title" parameter is the stage, which is what
        #   `title.parent()` gives.

    (Feel free to add more examples)
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the underlying `ana.ArkDb` with the given arguments and
        tag every stage with its `STAGE_INDEX`.
        """
        super().__init__(*args, **kwargs)
        self._reset_stage_indices()

    @contextlib.contextmanager
    def _ensure_valid_stage_indices(self):
        """
        Context manager that renumbers the stages if their count changed.

        Yields the number of stages on entry. On exit -- including when the
        body raises -- the `STAGE_INDEX` attributes are reassigned if the
        number of stages differs from the number on entry.
        """
        num_stages_before = len(self.stage)
        try:
            yield num_stages_before
        finally:
            if num_stages_before != len(self.stage):
                self._reset_stage_indices()

    def _reset_stage_indices(self):
        """
        Set `STAGE_INDEX` on every stage to its one-based position in the
        stage list.
        """
        for index, stg in enumerate(self.stage, start=1):
            # `stg` has one and only one key-value pair; the index is stored
            # as an attribute on that pair's value.
            stg_param = stg.values()[0]
            stg_param.STAGE_INDEX = index

    def __str__(self):
        """
        Return the stages as text, one ``<name> {<settings>}`` block per
        stage, with the blocks separated by a blank line.
        """
        blocks = []
        for stage in self._db.stage:
            # `stage` (`sea.Map`) has a single key-value pair.
            ((stg_name, stg_setting),) = stage.key_value()
            blocks.append("%s {\n%s}\n" %
                          (stg_name, stg_setting.__str__(ind="  ")))
        return "\n".join(blocks)

    @property
    def stage(self):
        """
        The stage list of the underlying database (`self._db.stage`).
        """
        return self._db.stage

    def get(self, key: str, *args, **kwargs):
        """
        Same as `ana.ArkDb.get`, but `key` may be a shorthand key (see the
        class docstring).
        """
        return super().get(_canonicalize_key(key), *args, **kwargs)

    def put(self, key: str, *args, **kwargs):
        """
        Same as `ana.ArkDb.put`, but `key` may be a shorthand key (see the
        class docstring). Stages are renumbered if the number of stages
        changes.
        """
        with self._ensure_valid_stage_indices():
            return super().put(_canonicalize_key(key), *args, **kwargs)

    def delete(self, key: str, *args, **kwargs):
        """
        Same as `ana.ArkDb.delete`, but `key` may be a shorthand key (see the
        class docstring). Stages left empty by the deletion are removed from
        the stage list, and the remaining stages are renumbered if the number
        of stages changes.
        """
        with self._ensure_valid_stage_indices():
            super().delete(_canonicalize_key(key), *args, **kwargs)
            # When we delete a stage like this: `msj.delete("[*].simulate")`,
            # `ArkDb.delete` will remove all "simulate" key-value pairs, but
            # leave empty `sea.Map` objects in the `self.stage` list. We need to
            # delete these empty objects.
            stages = self._db.stage
            empty_map_indices = [
                i for i, stage in enumerate(stages)
                if len(stage.values()) == 0
            ]
            # Delete from the end so that the remaining indices stay valid.
            for i in reversed(empty_map_indices):
                del stages[i]

    def find(self, key: str, *args, **kwargs):
        """
        Same as `ana.ArkDb.find`, but `key` may be a shorthand key (see the
        class docstring).
        """
        return super().find(_canonicalize_key(key), *args, **kwargs)

    def find_stages(self, key: str, *args, **kwargs) -> Tuple[sea.Map, ...]:
        """
        Similar to `find`, but to return a tuple of found stages. If no stages
        are found, this function returns an empty tuple. Examples::

            # To get all simulate stages:
            simulate_stages = msj.find_stages("[*].simulate")
            for stage in simulate_stages:
                print(stage.STAGE_INDEX)

            # To get the first simulate stage:
            first_simulate_stage = msj.find_stages("simulate")[0]
            print(first_simulate_stage.STAGE_INDEX)

            # To get the first simulate stage with "time = 200":
            simulate = msj.find_stages(
                "[*].simulate.time",
                picker=lambda x: (x.parent() if x.val == 200 else None))[0]
            print(simulate.STAGE_INDEX)

        `STAGE_INDEX` gives the index of the stage in the MSJ file. Note that
        the index is one based. Also, `STAGE_INDEX` is an attribute (as opposed
        to a key-value pair) of the returned `sea.Map` objects.
        """
        return tuple(arkdb._db for arkdb in self.find(key, *args, **kwargs))
def _pre_parse(string: str) -> str:
    """
    Rewrite raw MSJ text into a single ``stage = [...]`` list expression.

    The text is wrapped as ``stage = [ <text> ]`` and parsed with `sea.Map`.
    The resulting list elements are then paired up -- even positions as stage
    names, odd positions as stage settings -- and each pair is re-emitted as
    ``{<name> = {<settings>}}``.

    :param string: Raw content of an MSJ file.
    :return: Text of the form ``stage = [{name = {...}} ...]``.
    :raises SyntaxError: If any parsed element is a `sea.Atom` whose value is
        ``=``.
    """
    tokens = sea.Map("stage = [ %s ]" % string).stage
    # User might set a stage like: "stagename = {...}" by mistake.
    # Raises an exception for this type of errors.
    for token in tokens:
        if isinstance(token, sea.Atom) and token.val == "=":
            raise SyntaxError(
                "Stage name must not be followed by the assignment symbol: '='")
    stage_names = tokens[::2]
    stage_bodies = tokens[1::2]
    stage_texts = []
    for stage_name, stage_settings in zip(stage_names, stage_bodies):
        stage_texts.append("{%s = {%s}}" % (stage_name, stage_settings))
    return "stage = [%s]" % "\n".join(stage_texts)
def parse(fname=None, string=None) -> "Msj":
    """
    Parses a file or a string, and returns an `Msj` object.

    Either `fname` or `string` must be set (and non-empty), but not both.

    :param fname: Name of the MSJ file to read.
    :param string: Content of an MSJ file, given directly as text.
    :return: The `Msj` object built from the parsed content.
    :raises AssertionError: If neither or both of `fname` and `string` are
        set.
    """
    if bool(fname) == bool(string):
        # Raised explicitly instead of using an `assert` statement so that the
        # check is not stripped when Python runs with `-O`. The exception type
        # is kept as `AssertionError` for backward compatibility.
        raise AssertionError(
            "Either `fname` or `string` must be specified, but not both.")
    if fname:
        with open(fname) as fh:
            string = fh.read()
    return Msj(string=_pre_parse(string))