# Source code for schrodinger.application.canvas.utils
# -*- coding: utf-8 -*-
"""
Canvas utility functions.
Copyright Schrodinger, LLC. All rights reserved.
"""
import csv
import os
import stat
from schrodinger.application.licensing.licerror import LicenseException
from schrodinger.infra.canvas import ChmLicenseShared
from schrodinger.utils import csv_unicode
# Identifiers for the kinds of Canvas license a caller may ask for.
# Each constant's value is simply its own name as a string.
LICENSE_FULL = 'LICENSE_FULL'
LICENSE_SHARED = 'LICENSE_SHARED'
LICENSE_FULL_OR_MAIN = 'LICENSE_FULL_OR_MAIN'
LICENSE_MAIN = 'LICENSE_MAIN'

# Module-level slots for license objects; all of them start out unset.
shared_lic = None
full_or_main_lic = None
full_lic = None
[docs]def strip_r(file_name):
inputf = open(file_name, 'rb')
original = inputf.read()
inputf.close()
converted = original.replace(b'\r', b'')
os.chmod(file_name, stat.S_IRWXU)
outputf = open(file_name, 'wb')
outputf.write(converted)
outputf.close()
class chm_excel(csv.excel):
    """
    Comma-separated dialect derived from ``csv.excel``.

    Fields are separated by commas and quoted with double quotes only when
    necessary; each written row ends with a bare newline.
    """
    # Field and row layout.
    delimiter = ','
    lineterminator = '\n'
    skipinitialspace = False

    # Quoting rules: quote only when needed, double embedded quote chars.
    quoting = csv.QUOTE_MINIMAL
    quotechar = '"'
    doublequote = True
class chm_bluebird(csv.excel):
    """
    Tab-separated dialect derived from ``csv.excel``.

    Fields are separated by tabs and never quoted.  The line terminator is
    a tab followed by a newline, so every written row carries a trailing
    tab before the line break.
    """
    # Field and row layout.
    delimiter = '\t'
    lineterminator = '\t\n'
    skipinitialspace = False

    # Quoting is switched off; the quote settings are still spelled out.
    quoting = csv.QUOTE_NONE
    quotechar = '"'
    doublequote = True
class ChmCSVTools:
    """
    Helpers for inspecting, rewriting and merging delimited text files.

    Files are opened through ``csv_unicode.reader_open`` /
    ``csv_unicode.writer_open`` and parsed with the standard ``csv`` module.
    Column names are compared exactly when ``case_sensitive`` is true and
    after lower-casing otherwise.
    """

    def __init__(
        self,
        default_format_in=None,
        default_format_out=None,
        case_sensitive=True,
    ):
        """
        :param default_format_in: csv dialect used for reading when a method
            is not given one.  A falsy value selects ``chm_excel()``.
        :param default_format_out: csv dialect used for writing when a method
            is not given one.  A falsy value selects ``chm_bluebird()``.
        :type case_sensitive: bool
        :param case_sensitive: Whether column names must match exactly.
            When false, string names are lower-cased before comparison.
        """
        self.default_format_in = default_format_in
        self.default_format_out = default_format_out
        # Honour the caller's choice; it was previously forced to False.
        self.case_sensitive = case_sensitive
        if not self.default_format_in:
            self.default_format_in = chm_excel()
        if not self.default_format_out:
            self.default_format_out = chm_bluebird()

    def _normalize(self, name):
        """Return the form of ``name`` that is used for comparisons."""
        if not self.case_sensitive and isinstance(name, str):
            return name.lower()
        return name

    def _filter_items(self, mapping, key_list, action_remove):
        """
        Return the ``(key, value)`` pairs of ``mapping`` selected by
        ``key_list``: the matching pairs, or the non-matching ones when
        ``action_remove`` is true.  Order follows ``mapping``.
        """
        # Normalise the requested names once rather than once per dict key.
        wanted = [self._normalize(name) for name in (key_list or ())]
        remove = bool(action_remove)
        kept = []
        for key, value in mapping.items():
            cmpkey = self._normalize(key)
            found = any(cmpkey == name for name in wanted)
            if found != remove:
                kept.append((key, value))
        return kept

    @staticmethod
    def _pick_columns(row, column_indices, offset, action_remove):
        """
        Return the cells of ``row`` selected by ``column_indices``.

        With ``action_remove`` the listed positions are dropped and the
        remaining cells keep their order; otherwise the listed positions are
        returned in the order given.  ``offset`` is 1 for one-based indices
        and 0 for zero-based ones.
        """
        if action_remove:
            return [
                cell for pos, cell in enumerate(row)
                if pos + offset not in column_indices
            ]
        return [row[index - offset] for index in column_indices]

    def dictValues(
        self,
        dict,
        key_list=None,
        action_remove=False,
    ):
        """
        Return the values of ``dict`` whose keys are selected by ``key_list``.

        :param dict: Mapping to filter (for example a header as returned by
            `getHeader`).
        :param key_list: Names to look for; ``None`` means no names.
        :param action_remove: If true, return the values whose keys are NOT
            in ``key_list`` instead.
        :return: List of values, in the iteration order of ``dict``.
        """
        return [
            value
            for _, value in self._filter_items(dict, key_list, action_remove)
        ]

    def dictKeys(
        self,
        dict,
        key_list=None,
        action_remove=False,
    ):
        """
        Return the keys of ``dict`` that are selected by ``key_list``.

        :param dict: Mapping to filter.
        :param key_list: Names to look for; ``None`` means no names.
        :param action_remove: If true, return the keys that are NOT in
            ``key_list`` instead.
        :return: List of the original (un-normalised) keys, in the iteration
            order of ``dict``.
        """
        return [
            key
            for key, _ in self._filter_items(dict, key_list, action_remove)
        ]

    def getHeader(
        self,
        source_file,
        dialect=None,
        one_based=False,
    ):
        """
        Map each column name in the first row of a file to its position.

        :param source_file: File to read.
        :param dialect: csv dialect; defaults to ``self.default_format_in``.
        :param one_based: Number the columns from 1 instead of 0.
        :return: Dict of column name to position.  If a name occurs more than
            once, the last position wins.  An empty file yields ``{}``.
        """
        dialect = dialect if dialect else self.default_format_in
        with csv_unicode.reader_open(source_file) as source:
            # A default for next() avoids spinning forever on an empty file.
            first_row = next(csv.reader(source, dialect), None)
        if first_row is None:
            return {}
        start = 1 if one_based else 0
        return {col: pos for pos, col in enumerate(first_row, start)}

    def rewrite(
        self,
        source_file,
        dest_file,
        header_substitutions=None,
        dialect=None,
    ):
        """
        Copy a file, renaming columns in its first row.

        :param source_file: File to read.
        :param dest_file: File to write.
        :param header_substitutions: Mapping of existing column name to new
            column name; ``None`` means no renaming.  If several keys
            normalise to the same name, the first one in the mapping is used.
        :param dialect: csv dialect used for BOTH reading and writing;
            defaults to ``self.default_format_in``.
        """
        dialect = dialect if dialect else self.default_format_in
        # setdefault keeps the first substitution for each normalised name.
        replacements = {}
        for key, value in (header_substitutions or {}).items():
            replacements.setdefault(self._normalize(key), value)
        with csv_unicode.reader_open(source_file) as source, \
                csv_unicode.writer_open(dest_file) as dest:
            rdr = csv.reader(source, dialect)
            wtr = csv.writer(dest, dialect)
            header = next(rdr, None)
            if header is not None:
                wtr.writerow([
                    replacements.get(self._normalize(col), col)
                    for col in header
                ])
            # Everything after the header is copied through unchanged.
            wtr.writerows(rdr)

    def rewriteByIndex(
        self,
        source_file,
        dest_file,
        column_indices=None,
        dialect_in=None,
        dialect_out=None,
        one_based=False,
        action_remove=False,
    ):
        """
        Copy a file, keeping or dropping columns by position.

        :param source_file: File to read.
        :param dest_file: File to write.
        :param column_indices: Column positions; ``None`` means none.
        :param dialect_in: Reading dialect; defaults to
            ``self.default_format_in``.
        :param dialect_out: Writing dialect; defaults to
            ``self.default_format_out``.
        :param one_based: Interpret ``column_indices`` as starting at 1.
        :param action_remove: If true, drop the listed columns; otherwise
            write only the listed columns, in the order listed.
        :raises IndexError: In keep mode, if a row has no cell at a listed
            position.
        """
        dialect_in = dialect_in if dialect_in else self.default_format_in
        dialect_out = dialect_out if dialect_out else self.default_format_out
        offset = 1 if one_based else 0
        # Materialise once so a one-shot iterable works for every row.
        column_indices = list(column_indices or ())
        with csv_unicode.reader_open(source_file) as source, \
                csv_unicode.writer_open(dest_file) as dest:
            rdr = csv.reader(source, dialect_in)
            wtr = csv.writer(dest, dialect_out)
            for row in rdr:
                wtr.writerow(
                    self._pick_columns(row, column_indices, offset,
                                       action_remove))

    def mergeByIndex(
        self,
        source_file1,
        source_file2,
        dest_file,
        column_indices_file1=None,
        column_indices_file2=None,
        dialect_in1=None,
        dialect_in2=None,
        dialect_out=None,
        one_based=False,
        action_remove=False,
    ):
        """
        Merge two files row by row, selecting columns by position.

        Row *n* of the output is the selected cells of row *n* of the first
        file followed by the selected cells of row *n* of the second file.
        Merging stops as soon as either input runs out of rows.

        :param source_file1: First file to read.
        :param source_file2: Second file to read.
        :param dest_file: File to write.
        :param column_indices_file1: Column positions for the first file;
            ``None`` means none.
        :param column_indices_file2: Column positions for the second file;
            ``None`` means none.
        :param dialect_in1: Reading dialect for the first file; defaults to
            ``self.default_format_in``.
        :param dialect_in2: Reading dialect for the second file; defaults to
            ``self.default_format_in``.
        :param dialect_out: Writing dialect; defaults to
            ``self.default_format_out``.
        :param one_based: Interpret the indices as starting at 1.
        :param action_remove: If true, drop the listed columns; otherwise
            write only the listed columns, in the order listed.
        :raises IndexError: In keep mode, if a row has no cell at a listed
            position.
        """
        dialect_in1 = dialect_in1 if dialect_in1 else self.default_format_in
        dialect_in2 = dialect_in2 if dialect_in2 else self.default_format_in
        dialect_out = dialect_out if dialect_out else self.default_format_out
        offset = 1 if one_based else 0
        column_indices_file1 = list(column_indices_file1 or ())
        column_indices_file2 = list(column_indices_file2 or ())
        with csv_unicode.reader_open(source_file1) as source1, \
                csv_unicode.reader_open(source_file2) as source2, \
                csv_unicode.writer_open(dest_file) as dest:
            rdr1 = csv.reader(source1, dialect_in1)
            rdr2 = csv.reader(source2, dialect_in2)
            wtr = csv.writer(dest, dialect_out)
            # zip() stops at the shorter input, like the paired next() calls.
            for r1, r2 in zip(rdr1, rdr2):
                wtr.writerow(
                    self._pick_columns(r1, column_indices_file1, offset,
                                       action_remove) +
                    self._pick_columns(r2, column_indices_file2, offset,
                                       action_remove))

    def mergeByName(
        self,
        source_file1,
        source_file2,
        dest_file,
        column_names1=None,
        column_names2=None,
        dialect_in1=None,
        dialect_in2=None,
        dialect_out=None,
        action_remove=False,
    ):
        """
        Merge two files row by row, selecting columns by header name.

        The names are translated to zero-based positions using the first row
        of each file and the work is delegated to `mergeByIndex`.

        NOTE(review): ``action_remove`` is negated for the name-to-index
        lookup and then passed unchanged to `mergeByIndex`, so in both modes
        the output contains every column EXCEPT the named ones.  This looks
        unintended, but it is kept as is because changing it would alter the
        output for existing callers - confirm the intended semantics.

        :param source_file1: First file to read.
        :param source_file2: Second file to read.
        :param dest_file: File to write.
        :param column_names1: Column names in the first file; ``None`` means
            none.
        :param column_names2: Column names in the second file; ``None`` means
            none.
        :param dialect_in1: Reading dialect for the first file; defaults to
            ``self.default_format_in``.
        :param dialect_in2: Reading dialect for the second file; defaults to
            ``self.default_format_in``.
        :param dialect_out: Writing dialect; defaults to
            ``self.default_format_out``.
        :param action_remove: See the note above.
        :return: Whatever `mergeByIndex` returns.
        """
        one_based = False
        dialect_in1 = dialect_in1 if dialect_in1 else self.default_format_in
        dialect_in2 = dialect_in2 if dialect_in2 else self.default_format_in
        dialect_out = dialect_out if dialect_out else self.default_format_out
        column_indices_file1 = self.dictValues(
            self.getHeader(source_file1, dialect_in1, one_based),
            column_names1,
            not action_remove,
        )
        column_indices_file2 = self.dictValues(
            self.getHeader(source_file2, dialect_in2, one_based),
            column_names2,
            not action_remove,
        )
        return self.mergeByIndex(
            source_file1,
            source_file2,
            dest_file,
            column_indices_file1,
            column_indices_file2,
            dialect_in1,
            dialect_in2,
            dialect_out,
            one_based,
            action_remove,
        )
def get_license(license_type=LICENSE_FULL):
    """
    Return the module's shared Canvas license object, creating it on first
    use, or raise if that license is not valid.

    :type license_type: A module-level constant: LICENSE_FULL,
        LICENSE_SHARED, or LICENSE_FULL_OR_MAIN. Default is LICENSE_FULL.
    :param license_type: The type of license to request.  The body does not
        read this argument; a shared license is used for every value.

    :raises: LicenseException when the shared license reports that it is
        not valid.
    """
    global shared_lic
    # Due to CANVAS-4669, we allow all access to canvaslibs via python, so
    # a single shared license is created once and reused for every request.
    shared_lic = shared_lic or ChmLicenseShared(False)
    if shared_lic.isValid():
        return shared_lic
    raise LicenseException(
        "The requested Canvas license (type '%s') is not present or not valid."
        % "ChmLicenseShared")