# Source code for schrodinger.application.canvas.utils
# -*- coding: utf-8 -*-
"""
Canvas utility functions.
Copyright Schrodinger, LLC. All rights reserved.
"""
import csv
import os
import stat
from schrodinger.application.licensing.licerror import LicenseException
from schrodinger.infra.canvas import ChmLicenseShared
from schrodinger.utils import csv_unicode
# Identifiers for the license types that get_license() accepts as its
# ``license_type`` argument.
LICENSE_FULL = 'LICENSE_FULL'
LICENSE_SHARED = 'LICENSE_SHARED'
LICENSE_FULL_OR_MAIN = 'LICENSE_FULL_OR_MAIN'
LICENSE_MAIN = 'LICENSE_MAIN'

# Module-level license cache. get_license() fills in shared_lic on first
# use; full_or_main_lic and full_lic are not used by get_license().
shared_lic = None
full_or_main_lic = None
full_lic = None
def strip_r(file_name):
    """
    Remove every carriage-return byte (``b'\\r'``) from a file, in place.

    The file is read and rewritten in binary mode, so CRLF line endings
    become LF and lone CRs are dropped; no text decoding is performed.

    Before rewriting, the file's permission bits are replaced with
    ``stat.S_IRWXU`` (0o700: read/write/execute for the owner only), which
    also removes any group/other access.

    :param file_name: Path of the file to rewrite.
    :type file_name: str
    """
    with open(file_name, 'rb') as inputf:
        original = inputf.read()
    # NOTE(review): presumably done so that read-only files can be
    # rewritten; it discards the previous mode, including group/other bits.
    os.chmod(file_name, stat.S_IRWXU)
    with open(file_name, 'wb') as outputf:
        outputf.write(original.replace(b'\r', b''))
class chm_excel(csv.excel):
    """
    Comma-separated csv dialect that ends rows with a bare ``'\\n'``.

    Fields are quoted only when needed (``csv.QUOTE_MINIMAL``) using ``"``,
    and quote characters inside a field are doubled.
    """
    # Field separation and quoting.
    delimiter = ','
    quoting = csv.QUOTE_MINIMAL
    quotechar = '"'
    doublequote = True
    skipinitialspace = False
    # Record separation.
    lineterminator = '\n'
class chm_bluebird(csv.excel):
    """
    Tab-separated, unquoted csv dialect; each row ends with ``'\\t\\n'``.

    Because ``quoting`` is ``csv.QUOTE_NONE`` and no ``escapechar`` is set,
    the csv writer raises ``csv.Error`` for a field that contains the
    delimiter, the quote character or a line-terminator character.
    """
    # Field separation and quoting (quoting disabled).
    delimiter = '\t'
    quoting = csv.QUOTE_NONE
    quotechar = '"'
    doublequote = True
    skipinitialspace = False
    # Record separation; the trailing tab is part of every row.
    lineterminator = '\t\n'
class ChmCSVTools:
    """
    Helpers for selecting, renaming and merging columns of delimited files.

    Reading uses ``default_format_in`` as the csv dialect unless a method is
    given one explicitly. `rewriteByIndex`, `mergeByIndex` and `mergeByName`
    write with ``default_format_out`` by default; `rewrite` writes with the
    same dialect it reads with.
    """

    def __init__(
        self,
        default_format_in=None,
        default_format_out=None,
        case_sensitive=True,
    ):
        """
        :param default_format_in: Default csv dialect for reading; a
            `chm_excel` instance is used if this is falsy.
        :param default_format_out: Default csv dialect for writing; a
            `chm_bluebird` instance is used if this is falsy.
        :param case_sensitive: Accepted for backward compatibility but
            currently ignored; see the note below.
        """
        self.default_format_in = default_format_in or chm_excel()
        self.default_format_out = default_format_out or chm_bluebird()
        # NOTE(review): the ``case_sensitive`` argument has always been
        # ignored, so every instance matches names case-insensitively.
        # Honoring it now would silently make default-constructed instances
        # (default True) case-sensitive, so the legacy behavior is kept.
        # Assign ``self.case_sensitive = True`` after construction to opt in.
        self.case_sensitive = False

    def _foldCase(self, name):
        """Return *name* in the form used for name comparisons."""
        # Strings are lower-cased when matching is case-insensitive; other
        # values (and everything in case-sensitive mode) compare as-is.
        if isinstance(name, str) and not self.case_sensitive:
            return name.lower()
        return name

    def _selectItems(self, mapping, key_list, action_remove):
        """Return the (key, value) pairs of *mapping*, in iteration order,
        whose key is in *key_list* (or, if *action_remove*, is not)."""
        # Fold both sides the same way; list membership compares with ==.
        wanted = [self._foldCase(name) for name in key_list]
        remove = bool(action_remove)
        return [(key, value)
                for key, value in mapping.items()
                if (self._foldCase(key) in wanted) != remove]

    @staticmethod
    def _pickColumns(row, column_indices, offset, action_remove):
        """Return the cells of *row* at *column_indices* (or, if
        *action_remove*, every cell except those), counting from *offset*."""
        if action_remove:
            return [
                cell for index, cell in enumerate(row, offset)
                if index not in column_indices
            ]
        return [row[index - offset] for index in column_indices]

    def dictValues(
        self,
        dict,
        key_list=(),
        action_remove=False,
    ):
        """
        Return the values of *dict* selected by key.

        :param dict: Mapping to select from (for example, a header mapping
            returned by `getHeader`).
        :param key_list: Keys to look for. String keys are compared
            case-insensitively unless ``self.case_sensitive`` is true.
        :param action_remove: If false, return the values whose key is in
            *key_list*; if true, return the values whose key is not.
        :return: List of the selected values, in the iteration order of
            *dict*.
        """
        return [
            value
            for _, value in self._selectItems(dict, key_list, action_remove)
        ]

    def dictKeys(
        self,
        dict,
        key_list=(),
        action_remove=False,
    ):
        """
        Return the keys of *dict* selected by key.

        :param dict: Mapping to select from.
        :param key_list: Keys to look for. String keys are compared
            case-insensitively unless ``self.case_sensitive`` is true.
        :param action_remove: If false, return the keys of *dict* that are
            in *key_list*; if true, return the keys that are not.
        :return: List of the selected keys, as spelled in *dict*, in the
            iteration order of *dict*.
        """
        return [
            key for key, _ in self._selectItems(dict, key_list, action_remove)
        ]

    def getHeader(
        self,
        source_file,
        dialect=None,
        one_based=False,
    ):
        """
        Map each column name in the first row of *source_file* to its index.

        :param source_file: File to read (passed to
            ``csv_unicode.reader_open``).
        :param dialect: csv dialect used to parse the file; defaults to
            ``self.default_format_in``.
        :param one_based: If true, indices start at 1 instead of 0.
        :return: Dict of column name -> index, in column order. If a name
            occurs more than once, its last index wins. An empty file yields
            an empty dict.
        """
        dialect = dialect or self.default_format_in
        with csv_unicode.reader_open(source_file) as source:
            # next() with a default returns [] for an empty file instead of
            # raising StopIteration (which previously looped forever).
            first_row = next(csv.reader(source, dialect), [])
        start = 1 if one_based else 0
        return {col: index for index, col in enumerate(first_row, start)}

    def rewrite(
        self,
        source_file,
        dest_file,
        header_substitutions=None,
        dialect=None,
    ):
        """
        Copy *source_file* to *dest_file*, renaming header columns.

        Only the first row is changed: each column name that matches a key
        of *header_substitutions* (string keys compared case-insensitively
        unless ``self.case_sensitive`` is true) is replaced by the mapped
        value; the first matching key wins. All other rows are copied as-is.

        :param source_file: File to read (passed to
            ``csv_unicode.reader_open``).
        :param dest_file: File to write (passed to
            ``csv_unicode.writer_open``).
        :param header_substitutions: Dict of old column name -> new name.
            ``None`` means no substitutions.
        :param dialect: csv dialect used for both reading and writing;
            defaults to ``self.default_format_in``.
        """
        dialect = dialect or self.default_format_in
        # Fold the substitution keys once rather than once per column.
        substitutions = [
            (self._foldCase(key), value)
            for key, value in (header_substitutions or {}).items()
        ]
        with csv_unicode.reader_open(source_file) as source, \
                csv_unicode.writer_open(dest_file) as dest:
            rdr = csv.reader(source, dialect)
            wtr = csv.writer(dest, dialect)
            header = next(rdr, None)
            if header is None:
                return  # empty input: nothing is written
            new_header = []
            for col in header:
                cmpcol = self._foldCase(col)
                new_header.append(
                    next((value for cmpkey, value in substitutions
                          if cmpcol == cmpkey), col))
            wtr.writerow(new_header)
            wtr.writerows(rdr)

    def rewriteByIndex(
        self,
        source_file,
        dest_file,
        column_indices=(),
        dialect_in=None,
        dialect_out=None,
        one_based=False,
        action_remove=False,
    ):
        """
        Copy *source_file* to *dest_file*, keeping or dropping columns by
        index.

        :param column_indices: Column indices to keep or remove.
        :param dialect_in: Dialect for reading; defaults to
            ``self.default_format_in``.
        :param dialect_out: Dialect for writing; defaults to
            ``self.default_format_out``.
        :param one_based: If true, indices count from 1 instead of 0.
        :param action_remove: If false, write only the listed columns, in
            the listed order (a row that is too short raises IndexError); if
            true, write every column except the listed ones.
        """
        dialect_in = dialect_in or self.default_format_in
        dialect_out = dialect_out or self.default_format_out
        offset = 1 if one_based else 0
        with csv_unicode.reader_open(source_file) as source, \
                csv_unicode.writer_open(dest_file) as dest:
            rdr = csv.reader(source, dialect_in)
            wtr = csv.writer(dest, dialect_out)
            for row in rdr:
                wtr.writerow(
                    self._pickColumns(row, column_indices, offset,
                                      action_remove))

    def mergeByIndex(
        self,
        source_file1,
        source_file2,
        dest_file,
        column_indices_file1=(),
        column_indices_file2=(),
        dialect_in1=None,
        dialect_in2=None,
        dialect_out=None,
        one_based=False,
        action_remove=False,
    ):
        """
        Join two files side by side, row by row, selecting columns by index.

        For each pair of corresponding rows, the selected cells from
        *source_file1* are followed by the selected cells from
        *source_file2*. Writing stops at the end of the shorter file.

        :param column_indices_file1: Column indices for *source_file1*.
        :param column_indices_file2: Column indices for *source_file2*.
        :param dialect_in1: Dialect for *source_file1*; defaults to
            ``self.default_format_in``.
        :param dialect_in2: Dialect for *source_file2*; defaults to
            ``self.default_format_in``.
        :param dialect_out: Dialect for *dest_file*; defaults to
            ``self.default_format_out``.
        :param one_based: If true, indices count from 1 instead of 0.
        :param action_remove: If false, keep only the listed columns (in the
            listed order); if true, keep every column except the listed ones.
        """
        dialect_in1 = dialect_in1 or self.default_format_in
        dialect_in2 = dialect_in2 or self.default_format_in
        dialect_out = dialect_out or self.default_format_out
        offset = 1 if one_based else 0
        with csv_unicode.reader_open(source_file1) as source1, \
                csv_unicode.reader_open(source_file2) as source2, \
                csv_unicode.writer_open(dest_file) as dest:
            rdr1 = csv.reader(source1, dialect_in1)
            rdr2 = csv.reader(source2, dialect_in2)
            wtr = csv.writer(dest, dialect_out)
            # zip() stops as soon as either reader is exhausted.
            for row1, row2 in zip(rdr1, rdr2):
                wtr.writerow(
                    self._pickColumns(row1, column_indices_file1, offset,
                                      action_remove) +
                    self._pickColumns(row2, column_indices_file2, offset,
                                      action_remove))

    def mergeByName(
        self,
        source_file1,
        source_file2,
        dest_file,
        column_names1=(),
        column_names2=(),
        dialect_in1=None,
        dialect_in2=None,
        dialect_out=None,
        action_remove=False,
    ):
        """
        Join two files side by side, row by row, selecting columns by name.

        Each file's first row is read with `getHeader`, and the names in
        *column_names1* / *column_names2* are translated to column indices.
        Names not present in the header are ignored. The indices are then
        passed to `mergeByIndex`, so the header rows are merged like any
        other row.

        :param column_names1: Column names for *source_file1*.
        :param column_names2: Column names for *source_file2*.
        :param action_remove: If false, keep only the named columns (in the
            order they appear in each file); if true, keep every column
            except the named ones.
        :return: The return value of `mergeByIndex`.
        """
        one_based = False
        dialect_in1 = dialect_in1 or self.default_format_in
        dialect_in2 = dialect_in2 or self.default_format_in
        dialect_out = dialect_out or self.default_format_out
        # Always look up the indices of the *named* columns (action_remove
        # False); mergeByIndex then keeps or drops exactly those columns.
        column_indices_file1 = self.dictValues(
            self.getHeader(source_file1, dialect_in1, one_based),
            column_names1, False)
        column_indices_file2 = self.dictValues(
            self.getHeader(source_file2, dialect_in2, one_based),
            column_names2, False)
        return self.mergeByIndex(
            source_file1,
            source_file2,
            dest_file,
            column_indices_file1,
            column_indices_file2,
            dialect_in1,
            dialect_in2,
            dialect_out,
            one_based,
            action_remove,
        )
def get_license(license_type=LICENSE_FULL):
    """
    Return a valid Canvas license object, creating and caching it on first
    use.

    Due to CANVAS-4669, all access to the Canvas libraries from Python is
    allowed through the shared license. ``license_type`` is therefore
    accepted for backward compatibility but does not affect the result: a
    ``ChmLicenseShared`` object is always used, cached in the module-level
    ``shared_lic`` variable.

    :param license_type: The type of license requested. Ignored; see above.
    :type license_type: str (one of the module-level LICENSE_* constants)

    :return: The cached, valid shared license object.
    :rtype: ChmLicenseShared

    :raises LicenseException: If the shared license is not present or not
        valid. An invalid license is not cached, so a later call tries again.
    """
    global shared_lic
    if shared_lic is None:
        # Validate into a local first: caching an invalid license would make
        # every later call return it without raising.
        # NOTE(review): the meaning of the False argument is not visible here.
        lic = ChmLicenseShared(False)
        if not lic.isValid():
            raise LicenseException(
                "The requested Canvas license (type '%s') is not present or not valid."
                % "ChmLicenseShared")
        shared_lic = lic
    return shared_lic