"""
General classes and functions related to output files, in particular
when running under jobcontrol.
"""
# Contributors: Mark A. Watson
import glob
import os
import string
from schrodinger.application.jaguar import output
from schrodinger.application.jaguar import utils as jag_utils
from schrodinger.application.jaguar.constants import JNAME
from schrodinger.application.jaguar.input import JaguarInput
from schrodinger.application.matsci import msutils
from schrodinger.infra import mmjob
from schrodinger.job import jobcontrol
from schrodinger.job.queue import FINISHED
from schrodinger.structure import StructureReader
from schrodinger.structure import StructureWriter
from schrodinger.utils import fileutils
# Characters permitted in automatically constructed names: ASCII letters,
# digits, and the three punctuation characters '.', '_' and '-'.
LEGIT_CHARS = string.ascii_letters + string.digits + '._-'
# The FileLogger installed by "with FileLogger(...):", or None outside such
# a block; consulted by the module-level register_file() wrapper.
_filelogger = None
class FileLoggerError(Exception):
    """
    Raised when the module-level register_file() wrapper is called while no
    "with FileLogger:" context is active.
    """
def register_file(fname, logfile=False):
    """
    Module-level shortcut for FileLogger.register_file().

    Only usable inside a "with FileLogger(...):" block, which is what makes a
    logger instance available to this function. See the FileLogger docstring
    for more details.

    :type  fname: str
    :param fname: name of a file in the CWD
    :type  logfile: bool
    :param logfile: register as a streamed log file
    :raise FileLoggerError: if no FileLogger context is currently active
    """
    if _filelogger is None:
        raise FileLoggerError("Use 'with FileLogger:' to enable this wrapper.")
    _filelogger.register_file(fname, logfile)
class FileLogger(object):
    """
    Class to register output files. This is basically a wrapper for jobcontrol.

    Used as a context manager ("with FileLogger(...):"), the instance becomes
    the module-level logger so that the module function register_file() can
    be called without passing the instance around. On exit, whichever logger
    was active before entering is restored, so leaving a nested context does
    not disable an enclosing one.
    """

    def __init__(self, jobname, do_recover):
        """
        :type  jobname: str
        :param jobname: jobname; the recovery file is '<jobname>.recover'
        :type  do_recover: bool
        :param do_recover: if False, update the .recover file in the launch
                           directory every time a file is registered
        """
        self.recover_file = jobname + '.recover'
        self.do_recover = do_recover
        # Loggers that were active when this instance was entered; popped on
        # exit so that nested (or re-entered) contexts unwind correctly.
        self._previous_loggers = []

    def __enter__(self):
        """
        Support context management "with" statement.
        On entering the context, set a module variable so that we can
        conveniently use this class without passing a class instance around
        the backend scripts.
        """
        global _filelogger
        self._previous_loggers.append(_filelogger)
        _filelogger = self
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """
        Support context management "with" statement.
        Restore the logger that was active before this context was entered
        (None at the outermost level). Exceptions are not suppressed.
        """
        global _filelogger
        if self._previous_loggers:
            _filelogger = self._previous_loggers.pop()
        else:
            _filelogger = None

    def register_file(self, fname, logfile=False):
        """
        Register file in CWD as a jobcontrol output file.
        If the file already exists, this call will also request JC to copy it
        to the launch machine imminently, which is useful for returning
        intermediate results before the whole workflow ends.  e.g. such output
        files could be re-used in a recover/restart in case of job failure.
        To stream the file, use logfile=True. But note this doesn't work well
        where the change cannot simply be appended.
        When do_recover is False, the file is also appended to the .recover
        file in the launch directory.

        :type  fname: str
        :param fname: file name
        :type  logfile: bool
        :param logfile: register as a streamed log file
        """
        # Paths given to jobcontrol are relative to the launch directory.
        relpath = os.path.normpath(os.path.join(relative_path(), fname))
        jobbe = jobcontrol.get_backend()
        if jobbe:
            # JAGUAR-9435: we distinguish between jobcontrol and jobserver here
            if mmjob.mmjob_is_job_server_job(jobbe.job_id):
                # LogFile and OutputFile are considered mutually exclusive file
                # categories by jobserver so the logic used for legacy
                # jobcontrol is invalid.
                if logfile:
                    # Register as a log file so that it can be streamed.
                    jobbe.addLogFile(relpath)
                elif os.path.exists(fname):
                    # Copy file from remote to launch machine and
                    # register it as an output file.
                    jobbe.copyOutputFile(relpath)
                else:
                    # Register as an output file to be returned at end of job.
                    jobbe.addOutputFile(relpath)
            else:
                # For legacy jobcontrol file streaming to work correctly, we
                # must overwrite any file with the same name currently in the
                # launch directory to ensure proper appending subsequently.
                # (See JAGUAR-9435 for more details).
                if logfile:
                    # Register as a log file so that it can be streamed.
                    jobbe.addLogFile(relpath)
                if os.path.exists(fname):
                    # Copy file from remote to launch machine and
                    # register it as an output file.
                    jobbe.copyOutputFile(relpath)
                elif not logfile:
                    # Register as an output file to be returned at end of job.
                    jobbe.addOutputFile(relpath)
        if not self.do_recover:
            # Update recovery file
            recover_file = os.path.join(launch_path(), self.recover_file)
            jag_utils.append_outfiles_to_recover_file(recover_file, [relpath])
def launch_path():
    """
    Return the directory from which the current job was launched.

    Under jobcontrol this is the job's JobDir with symlinks resolved; when not
    running under jobcontrol (e.g. a local run) it is simply the CWD.

    :rtype: str
    :return: absolute path of the launch directory
    """
    backend = jobcontrol.get_backend()
    if not backend:
        return os.getcwd()
    # JobDir needs to be corrected to a "realpath" (presumably so it compares
    # cleanly against os.getcwd() in relative_path()).
    launch_dir = backend.getJob().JobDir
    return os.path.realpath(launch_dir)
def relative_path():
    """
    Return the path of the CWD relative to the current job's launch directory.

    When the CWD *is* the launch directory, an empty string is returned rather
    than '.', so the result can be fed straight into os.path.join().

    :rtype: str
    :return: relative path from the launch directory to the CWD, or ''
    """
    here = os.path.normpath(os.getcwd())
    launch_dir = os.path.normpath(launch_path())
    if here == launch_dir:
        return ''
    # Normalize to get rid of any redundant '.' / separator components.
    return os.path.normpath(os.path.relpath(here, start=launch_dir))
def set_structure_file(fname):
    """
    Register fname as the job's output structure file with jobcontrol.

    fname is assumed to live in the CWD; it is registered relative to the
    launch directory. Does nothing when not running under jobcontrol.

    :type  fname: str
    :param fname: structure file name in the CWD
    """
    backend = jobcontrol.get_backend()
    if not backend:
        return
    backend.setStructureOutputFile(os.path.join(relative_path(), fname))
def copy_file(fname):
    """
    Ask jobcontrol to copy fname from the CWD back to the launch directory.

    Does nothing when not running under jobcontrol.

    :type  fname: str
    :param fname: file name in the CWD
    """
    backend = jobcontrol.get_backend()
    if not backend:
        return
    backend.copyOutputFile(os.path.join(relative_path(), fname))
def transfer_subjob_files(job_id):
    """
    Register files held in a job record from the working dir to
    the launch dir associated with a jobcontrol backend.

    Input files (by basename), output files, log files and the structure
    output file of the job are all registered via the module-level
    register_file(), from inside the job's directory (job.Dir). This
    function can handle jobs launched in subdirectories.

    :type job_id: jobcontrol.Job.JobID
    :param job_id: jobcontrol job id
    :raise FileLoggerError: from register_file() if there are files to
                            register and no "with FileLogger:" context is active
    """
    job = jobcontrol.Job(job_id)
    with fileutils.chdir(job.Dir):
        # We need to chdir to the job.Dir for JAGUAR-9553, but usually
        # the CWD and job.Dir are the same and this is a no-op.
        for ifile in job.getInputFiles():
            register_file(os.path.basename(ifile))
        for ofile in job.getOutputFiles():
            register_file(ofile)
        for lfile in job.LogFiles:
            register_file(lfile)
        # The structure output file is presumably relative to job.Dir like the
        # files above, so it must be registered before leaving job.Dir.
        # Registering it after the chdir is undone gave the wrong relative
        # path and existence check whenever job.Dir differed from the CWD.
        stoutfile = job.StructureOutputFile
        if stoutfile:
            register_file(stoutfile)
def slugify(mystr):
    """
    Make a string usable as a file or job name.

    Every character not in LEGIT_CHARS (ASCII letters, digits, '.', '_' and
    '-') is dropped; the remaining characters keep their original order.

    :type  mystr: str
    :param mystr: arbitrary text
    :rtype: str
    :return: mystr with all disallowed characters removed
    """
    kept = filter(LEGIT_CHARS.__contains__, mystr)
    return ''.join(kept)
def make_outmaefile(outmaefile,
                    infiles,
                    status,
                    write_jname=False,
                    include_failures=False):
    """
    Collect output CTs from Jaguar jobs into a single .mae file.

    For each selected subjob, the CT is taken from its restart .mae file when
    that exists in the CWD, otherwise from the structure in its input file;
    subjobs with neither are skipped. The output file is always closed, even
    if reading a structure fails.

    :type  outmaefile: str
    :param outmaefile: name of output .mae file
    :type  infiles: list of strs
    :param infiles: subjob input files, including suffix, e.g. mol1.in
    :type  status: dictionary
    :param status: status of each subjob indexed by filename
    :type  write_jname: bool
    :param write_jname: If True, store each subjob's name (input file name
                        without suffix) in the JNAME property of its CT
    :type include_failures: bool
    :param include_failures: If True include failures in output maestro file
                             and group structures by status.  If False only
                             successful jobs are retained.
    """
    all_cts = StructureWriter(outmaefile)
    try:
        for infile in infiles:
            infile_base = os.path.basename(infile)
            success = status.get(infile_base, None) == FINISHED
            if not (include_failures or success):
                continue
            basename = os.path.splitext(infile_base)[0]
            maefile = output.restart_name(basename) + ".mae"
            ct = None
            if os.path.exists(maefile):
                ct = next(StructureReader(maefile))
            elif os.path.exists(infile):
                ct = JaguarInput(infile).getStructure()
            if ct is None:
                continue
            # Adding this property is useful for mapping structures
            # to filenames, e.g. to create .smap files. (JAGUAR-6846)
            if write_jname:
                ct.property[JNAME] = basename
            # group structures by status
            if include_failures:
                msutils.set_project_group_hierarchy(
                    ct, ["job_status=%s" % status.get(infile_base, "N/A")])
            all_cts.append(ct)
    finally:
        all_cts.close()
def make_smapfile(outmaefile, smapfile):
    """
    Write a .smap file containing the associations between CT
    index numbers in the outmaefile, and .vib, .vis, .spm files in the CWD.

    This function relies on CT's in the .mae file having a
    property 's_j_jname' (stored in JNAME) which maps to the name
    of the .vib etc file. Each data file is mapped to the CT whose JNAME is
    the longest substring of the data file name; data files matching no CT
    are left out.

    :type  outmaefile: str
    :param outmaefile: structure file whose CTs carry the JNAME property
    :type  smapfile: str
    :param smapfile: name of the .smap file to write
    """
    # unittested
    with open(smapfile, "w") as smap:
        smap.write('# smap version 1.0\n')
        smap.write('%s\n' % outmaefile)
        # It's useful to sort these for STU testing
        datnames = (sorted(glob.glob('*.vib')) + sorted(glob.glob('*.vis')) +
                    sorted(glob.glob('*.spm')))
        # Read each CT's JNAME once instead of re-reading outmaefile for every
        # data file. As before, outmaefile is not read at all when there are
        # no data files to map.
        jnames = []
        if datnames:
            jnames = [
                ct.property.get(JNAME, None)
                for ct in StructureReader(outmaefile)
            ]
        for datname in datnames:
            ct_id = _longest_jname_match(datname, jnames)
            if ct_id:
                smap.write('%s: %d\n' % (datname, ct_id))
        smap.write('#end\n')


def _longest_jname_match(datname, jnames):
    """
    Return the 1-based index of the longest non-empty jname contained in
    datname, or 0 if none is contained. Earlier entries win ties.

    e.g. test10_HOMO+1.vis will match against test10 in preference to
    test1. (JAGUAR-6846).

    :type  datname: str
    :param datname: data file name (.vib/.vis/.spm)
    :type  jnames: list
    :param jnames: JNAME property of each CT in file order (None if unset)
    :rtype: int
    """
    ct_id = 0
    match = ''
    for index, jname in enumerate(jnames, start=1):
        if jname and jname in datname and len(jname) > len(match):
            match = jname
            ct_id = index
    return ct_id