"""
Container for common test functions used in Canvas tests
"""
import csv
import errno
import os
from itertools import zip_longest

from schrodinger.utils import csv_unicode
from schrodinger.utils import subprocess
def read_csv(filename):
    """
    Read a canvasMCS csv formatted file.

    Yields one dict per data row. The keys are taken from the file's header
    line (``csv.DictReader`` default behaviour).
    """
    with csv_unicode.reader_open(filename) as handle:
        for row in csv.DictReader(handle):
            yield row
def read_pw(filename):
    """
    Read a canvasMCS pw formatted file (pw stands for pair-wise).

    The file is parsed as CSV. Each yielded value is a dict for one data
    row, keyed by the column names on the header line.
    """
    with csv_unicode.reader_open(filename) as pw_file:
        records = csv.DictReader(pw_file)
        for record in records:
            yield record
def run_mcs_pw(inputfile, atom_typing=97):
    """
    Run canvasMCS on `inputfile` and produce pairwise (-opw) output.

    :param inputfile: Path of the Maestro input file, passed via -imae.
    :param atom_typing: Value passed to canvasMCS -atomtype (default 97).
    :return: Path of the pairwise output file. It is `inputfile` with '.mae'
        in the file name replaced by '.pw'. If the file name contains no
        '.mae', '.pw' is appended instead.
    :raises subprocess.CalledProcessError: if canvasMCS exits non-zero; the
        failing command and its output are printed before re-raising.
    """
    # Rewrite only the file name, never the directory part. If the name
    # has no '.mae', append '.pw' so the output can never be the input
    # file itself, which canvasMCS would otherwise overwrite.
    directory, basename = os.path.split(inputfile)
    if '.mae' in basename:
        basename = basename.replace('.mae', '.pw')
    else:
        basename += '.pw'
    outputfile = os.path.join(directory, basename)
    command = [
        'canvasMCS', '-imae', inputfile, '-opw', outputfile, '-atomtype',
        str(atom_typing), '-silent'
    ]
    try:
        subprocess.check_output(command, universal_newlines=True)
    except subprocess.CalledProcessError as err:
        print(' '.join(command), 'failed with exit code', err.returncode)
        print(err.output)
        raise
    return outputfile
def run_mcs(cmd):
    """
    Run a command and check return code and that the phrase
    "successfully completed" is printed.

    :param cmd: The command as a list of strings (program, then arguments).
    :return: stdout from the canvasMCS job.
    :raises subprocess.CalledProcessError: if the command exits non-zero; the
        command and its output are printed before re-raising.
    :raises OSError: with errno ENOENT and the full command line in the
        message if the program cannot be found. Any other OSError (e.g.
        permission denied) propagates unchanged.
    :raises AssertionError: if "successfully completed" does not appear in
        stdout; the message contains the command and its output.
    """
    try:
        output = subprocess.check_output(cmd, universal_newlines=True)
    except subprocess.CalledProcessError as err:
        print(' '.join(cmd), 'failed with exit code', err.returncode)
        print(err.output)
        raise
    except OSError as err:
        if err.errno == errno.ENOENT:
            # Re-raise with the whole command line in the message.
            raise OSError(
                errno.ENOENT,
                'No such file or directory: {}'.format(' '.join(cmd))) from err
        # Any other OSError must propagate: falling through would leave
        # `output` unbound below.
        raise
    if 'successfully completed' not in output:
        raise AssertionError(
            '"successfully completed" not found in output of {}:\n{}'.format(
                ' '.join(cmd), output))
    return output
def get_map(list1, list2):
    """
    Get a map of elements in list1 to elements in list2. Fails if the mapping
    is not one to one. list1 and list2 are strings of comma separated integers.

    :param list1: Comma separated integers used as keys; an empty string or
        None means no items.
    :param list2: Comma separated integers used as values, paired with list1
        by position; an empty string or None means no items.
    :return: dict mapping each int in list1 to the int at the same position
        in list2.
    :raises ValueError: if the lists differ in length, or if either list
        contains a repeated item.
    """
    list1 = [int(i) for i in list1.split(',')] if list1 else []
    list2 = [int(i) for i in list2.split(',')] if list2 else []
    if len(list1) != len(list2):
        raise ValueError('Lists are not the same length! '
                         '({} != {})'.format(len(list1), len(list2)))
    mapping = dict(zip(list1, list2))
    if len(mapping) != len(list1):
        raise ValueError('At least one item from list1 is repeated. '
                         '{}: {}'.format(list1, list2))
    # A repeated value would send two keys to the same target, so the map
    # would not be one to one.
    if len(set(list2)) != len(list2):
        raise ValueError('At least one item from list2 is repeated. '
                         '{}: {}'.format(list1, list2))
    return mapping
def assertAtomMapsEqual(left_map, left_filename, right_map, right_filename):
    """
    Check that two dicts representing atom mappings are equal.
    Functionally equivalent to assertEqual(left_map, right_map), but more
    precise error reporting.

    A key missing from one map is treated as mapping to None, so it compares
    equal to an explicit None value in the other map.

    :raises AssertionError: if the maps differ. The message has a header line
        "<left_filename> != <right_filename>", then one line per differing
        key of the form "k->left_value != k->right_value", in sorted key order.
    """
    # Right-align each left entry to the width of left_filename so every
    # '!=' lines up with the one in the header line.
    inset = len(left_filename)
    diffs = []
    for k in sorted(set(left_map) | set(right_map)):
        left_value = left_map.get(k)
        right_value = right_map.get(k)
        if left_value != right_value:
            left_data = '{k}->{v1}'.format(k=k, v1=left_value)
            diffs.append('{left_data:>{inset}} != {k}->{v2}'.format(
                k=k, left_data=left_data, inset=inset, v2=right_value))
    if diffs:
        raise AssertionError(
            "Atom maps didn't match:\n" +
            "{} != {}\n".format(left_filename, right_filename) +
            '\n'.join(diffs))
def assertEqual(a, b):
    """
    Raise AssertionError if `a != b`.

    The message is "Items did not match:", a newline, then str(a) and str(b)
    separated by a space.
    """
    if a != b:
        # str.format rather than '+' so non-string items produce the intended
        # AssertionError instead of a TypeError from string concatenation.
        raise AssertionError('Items did not match:\n{} {}'.format(a, b))
def assertCSVFilesMatch(test_file, reference_file):
    """
    Check that the MCS atom mapping in `test_file` matches `reference_file`.

    Only the first two records in each file are considered. For each file,
    the s_canvas_MCS_Atom_List of the first record is mapped onto that of the
    second record with get_map. The two resulting maps are then compared with
    assertAtomMapsEqual, so a failure lists every differing atom pair.

    The current fields in a MCS csv file are:
    SMILES
    Name
    i_canvas_MCS_Match_Count
    i_canvas_MCS_Size
    i_canvas_MCS_Group
    i_canvas_MCS_Atom_Count
    i_canvas_MCS_Bond_Count
    s_canvas_MCS_Atom_List
    s_canvas_MCS_Bond_List
    s_canvas_MCS_SMARTS
    Only s_canvas_MCS_Atom_List is compared; the other fields are not checked.

    :raises AssertionError: if either file has fewer than two records, or if
        the atom mappings differ.
    :raises ValueError: from get_map if an atom-list pair differs in length or
        is not one to one.
    """
    test_atom_lists = _first_two_atom_lists(test_file)
    reference_atom_lists = _first_two_atom_lists(reference_file)
    assertAtomMapsEqual(
        get_map(*test_atom_lists),
        os.path.basename(test_file),
        get_map(*reference_atom_lists),
        os.path.basename(reference_file))


def _first_two_atom_lists(filename):
    """Return the s_canvas_MCS_Atom_List values of the first two records."""
    # Materialising the generator runs it to completion, which exits its
    # `with` block and so closes the file.
    records = list(read_csv(filename))
    if len(records) < 2:
        raise AssertionError(
            '{} has {} record(s); at least two are needed to compare atom '
            'mappings'.format(filename, len(records)))
    return (records[0]['s_canvas_MCS_Atom_List'],
            records[1]['s_canvas_MCS_Atom_List'])