from ruamel import yaml
from voluptuous import All
from voluptuous import Any
from voluptuous import Range
from voluptuous import Required
from voluptuous import Schema
from schrodinger.application.vss.csvsmiles import CsvSmilesFile
# Default for the ``n_report`` setting shared by the stage sections.
_DEFAULT_N_REPORT = 100
# Default and bounds for ``training_time``. The bounds (1/3600 .. 24)
# suggest the unit is hours -- TODO confirm.
# NOTE: the default must be a float. voluptuous inserts default values
# *before* validating them, and the ``float`` validator below is an
# isinstance check, so an int default (``1``) would make every section that
# omits ``training_time`` fail with "expected float".
# NOTE(review): values written as YAML integers (e.g. ``2``) are likewise
# rejected; consider ``Coerce(float)`` if that is not intended.
_DEFAULT_TRAINING_TIME = 1.0
_MIN_TRAINING_TIME = 1 / 60 / 60
_MAX_TRAINING_TIME = 24
# Schema fragment: ``n_report`` is a positive integer with a default.
_N_REPORT = {
    Required('n_report', default=_DEFAULT_N_REPORT): All(int, Range(min=1))
}
# Schema fragment shared by the machine-learning sections: ``n_report`` plus
# a bounded float ``training_time``.
_ML_COMMON = {
    **_N_REPORT,
    Required('training_time', default=_DEFAULT_TRAINING_TIME): All(
        float, Range(min=_MIN_TRAINING_TIME, max=_MAX_TRAINING_TIME)),
}
class ShapeControl:
    '''
    Settings for the ``shape`` section of the control file.
    '''
    # Key under which this section appears in the control file.
    NAME = 'shape'
    SCHEMA = Schema({
        'query': str,
        Required('shape_type'): Any('pharm', 'atom_color', 'atom_no_color'),
        **_N_REPORT,
    })

    def __init__(self, *, n_report, shape_type, query=None):
        '''
        :param n_report: ``n_report`` setting (positive integer per schema).
        :type n_report: int
        :param shape_type: One of 'pharm', 'atom_color', 'atom_no_color'.
        :type shape_type: str
        :param query: Optional query file; reported by `input_files` when
            set.
        :type query: str or None
        '''
        self.n_report = n_report
        self.query = query
        self.shape_type = shape_type

    @property
    def input_files(self):
        '''
        :return: The query file if one is set (non-empty), else nothing.
        :rtype: list(str)
        '''
        return [self.query] if self.query else []
class GlideControl:
    '''
    Settings for the ``glide`` section of the control file.
    '''
    # Key under which this section appears in the control file.
    NAME = 'glide'
    SCHEMA = Schema({
        Required('grid'): str,
        **_N_REPORT,
    })

    def __init__(self, *, n_report, grid):
        '''
        :param n_report: ``n_report`` setting (positive integer per schema).
        :type n_report: int
        :param grid: Grid file; reported by `input_files`.
        :type grid: str
        '''
        self.n_report = n_report
        self.grid = grid

    @property
    def input_files(self):
        '''
        :return: Files this section needs: the grid file.
        :rtype: list(str)
        '''
        return [self.grid]
class GlideALControl:
    '''
    Settings for the ``glide_al`` section of the control file.
    '''
    # Key under which this section appears in the control file.
    NAME = 'glide_al'
    SCHEMA = Schema({
        Required('grid'): str,
        **_ML_COMMON,
    })

    def __init__(self, *, n_report, grid, training_time):
        '''
        :param n_report: ``n_report`` setting (positive integer per schema).
        :type n_report: int
        :param grid: Grid file; reported by `input_files`.
        :type grid: str
        :param training_time: Training time, bounded by the schema to
            [1/3600, 24] (presumably hours -- TODO confirm).
        :type training_time: float
        '''
        self.n_report = n_report
        self.grid = grid
        self.training_time = training_time

    @property
    def input_files(self):
        '''
        :return: Files this section needs: the grid file.
        :rtype: list(str)
        '''
        return [self.grid]
class LigandMLControl:
    '''
    Settings for the ``ligand_ml`` section of the control file.
    '''
    # Key under which this section appears in the control file.
    NAME = 'ligand_ml'
    SCHEMA = Schema({**_ML_COMMON})

    def __init__(self, *, n_report, training_time):
        '''
        :param n_report: ``n_report`` setting (positive integer per schema).
        :type n_report: int
        :param training_time: Training time, bounded by the schema to
            [1/3600, 24] (presumably hours -- TODO confirm).
        :type training_time: float
        '''
        self.n_report = n_report
        self.training_time = training_time

    @property
    def input_files(self):
        '''
        :return: Always empty; this section references no files.
        :rtype: list(str)
        '''
        return []
class DiseControl:
    '''
    Settings for the ``dise`` section of the control file.
    '''
    # Key under which this section appears in the control file.
    NAME = 'dise'
    # Both values must be floats in the half-open interval (0, 1].
    SCHEMA = Schema({
        Required('seed', default=0.1): All(
            float, Range(min=0.0, max=1.0, min_included=False)),
        Required('similarity', default=0.5): All(
            float, Range(min=0.0, max=1.0, min_included=False)),
    })

    def __init__(self, *, seed, similarity):
        '''
        :param seed: ``seed`` setting, in (0, 1] per schema.
        :type seed: float
        :param similarity: ``similarity`` setting, in (0, 1] per schema.
        :type similarity: float
        '''
        self.seed = seed
        self.similarity = similarity
# Schema for the whole control file. ``actives`` is the only mandatory
# section. Every stage section is keyed by its control class's ``NAME`` and
# validated by that class's ``SCHEMA``.
CONTROL_FILE_SCHEMA = Schema({
    'jobname': str,
    'databases': [str],
    Required('actives'): CsvSmilesFile.SCHEMA,
    'decoys': CsvSmilesFile.SCHEMA,
    **{
        control.NAME: control.SCHEMA
        for control in (DiseControl, GlideControl, ShapeControl,
                        GlideALControl, LigandMLControl)
    },
})
class RunnerControl:
    '''
    Run specification assembled from a control file.
    '''

    def __init__(self,
                 *,
                 jobname=None,
                 databases=None,
                 dise=None,
                 actives=None,
                 decoys=None,
                 shape=None,
                 glide=None,
                 glide_al=None,
                 ligand_ml=None,
                 **kwargs):
        '''
        :param jobname: Job name.
        :type jobname: str or None
        :param databases: Database entries; defaults to an empty list.
        :type databases: list(str) or None
        :param dise: ``dise`` section.
        :type dise: `DiseControl` or None
        :param actives: ``actives`` section.
        :type actives: `CsvSmilesFile` or None
        :param decoys: ``decoys`` section.
        :type decoys: `CsvSmilesFile` or None
        :param shape: ``shape`` section.
        :type shape: `ShapeControl` or None
        :param glide: ``glide`` section.
        :type glide: `GlideControl` or None
        :param glide_al: ``glide_al`` section.
        :type glide_al: `GlideALControl` or None
        :param ligand_ml: ``ligand_ml`` section.
        :type ligand_ml: `LigandMLControl` or None
        :param kwargs: Any other entries; accepted and ignored.
        '''
        self.jobname = jobname
        self.databases = databases or []
        self.dise = dise
        self.actives = actives
        self.decoys = decoys
        self.shape = shape
        self.glide = glide
        self.glide_al = glide_al
        self.ligand_ml = ligand_ml

    @property
    def input_files(self):
        '''
        Collects the ``input_files`` of every section that is set, in the
        order actives, decoys, shape, glide, glide_al, ligand_ml. ``dise``
        is not consulted.

        :return: Concatenated input files.
        :rtype: list(str)
        '''
        files = []
        for name in ('actives', 'decoys', 'shape', 'glide', 'glide_al',
                     'ligand_ml'):
            spec = getattr(self, name)
            if spec:
                files.extend(spec.input_files)
        return files
def get_control_from_dict(data):
    '''
    Instantiates `RunnerControl` from a
    dictionary that conforms to the `CONTROL_FILE_SCHEMA`.

    The ``actives``/``decoys`` entries become `CsvSmilesFile` objects. Each
    stage section becomes an instance of its control class. All other
    entries are passed to `RunnerControl` unchanged. Each section dict must
    supply every keyword its constructor requires, which holds once
    `CONTROL_FILE_SCHEMA` has filled in the defaults.

    :param data: Control dictionary.
    :type data: dict
    :return: Run specification.
    :rtype: `RunnerControl`
    '''
    params = dict(data)
    # Membership tests instead of ``try/except KeyError``: a KeyError raised
    # *inside* a constructor must propagate, not be mistaken for a missing
    # section (which left the raw dict in ``params``).
    for name in ('actives', 'decoys'):
        if name in data:
            params[name] = CsvSmilesFile(**data[name])
    for cls in (DiseControl, ShapeControl, GlideControl, GlideALControl,
                LigandMLControl):
        if cls.NAME in data:
            params[cls.NAME] = cls(**data[cls.NAME])
    return RunnerControl(**params)
def get_control_from_file(filename):
    '''
    Reads and parses a YAML "control" file, validates it against
    `CONTROL_FILE_SCHEMA`, and instantiates the "run specification".

    :param filename: Filename.
    :type filename: str
    :return: Run specification.
    :rtype: `RunnerControl`
    :raise voluptuous.MultipleInvalid: If the file content does not conform
        to `CONTROL_FILE_SCHEMA`.
    '''
    # YAML text is UTF-8 by spec; don't depend on the locale's encoding.
    with open(filename, 'r', encoding='utf-8') as fp:
        if hasattr(yaml, 'YAML'):
            # Module-level ``safe_load`` is deprecated in ruamel.yaml and
            # removed in 0.18. ``pure=True`` keeps the pure-Python safe
            # loader that ``safe_load`` used.
            data = yaml.YAML(typ='safe', pure=True).load(fp)
        else:
            # Very old ruamel.yaml without the ``YAML`` class.
            data = yaml.safe_load(fp)
    return get_control_from_dict(CONTROL_FILE_SCHEMA(data))