"""
Utilities to read and validate data from a CSV file containing antibody sequence
information.
"""
from contextlib import contextmanager
from dataclasses import dataclass
import csv
import enum
from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import Generator
from typing import List
from typing import Optional
from typing import Tuple
from schrodinger.utils import csv_unicode
from schrodinger.utils import fileutils
try:
from schrodinger.application.prime.packages import antibody
except ImportError:
antibody = None
# Headers whose cells hold something other than a chain sequence.
NON_SEQ_HEADERS = (Headers.NAME, Headers.EXTRA_COLS)
# Every other header holds a chain sequence; validators skip these columns
# unless sequence validation is explicitly requested.
SEQ_HEADERS = tuple(
    filter(lambda header: header not in NON_SEQ_HEADERS, Headers))
class ModelingMode(enum.Enum):
    """
    The kinds of antibody modeling this module distinguishes between:
    single-domain, monospecific and bispecific.
    """
    # Explicit values equal to what `enum.auto()` assigns (starting at 1).
    SINGLE_DOMAIN = 1
    MONOSPECIFIC = 2
    BISPECIFIC = 3
# ==============================================================================
# Errors
# ==============================================================================
@enum.unique
class DataViolation(enum.Enum):
    """
    Categories of invalid data that may be present in an antibody CSV file.
    Each member's value is a human-readable description of the problem.
    """
    NAME = 'invalid name'
    HC_SEQ = 'invalid heavy chain sequence'
    LC_SEQ = 'invalid light chain sequence'
    NUM_COLS = 'invalid number of columns'
    HEADERS = 'invalid headers'
class BaseInvalidAntibodyCSVError(Exception):
    """
    Base class for exceptions relating to invalid antibody csv files. Subclasses
    should define the message to display by overriding `__str__`.

    :param csv_file: Path of the offending CSV file.
    """

    def __init__(self, csv_file: str):
        # Forward csv_file to Exception so it is recorded in `self.args`.
        # Pickle re-creates exceptions by calling `cls(*self.args)`; with an
        # empty `args` neither this class nor its subclasses could be
        # unpickled (e.g. when sent back from a worker process).
        super().__init__(csv_file)
        self._csv_file = csv_file
class InvalidFileTypeError(BaseInvalidAntibodyCSVError):
    """
    Raised when the supplied antibody file is not recognized as a CSV file.
    """

    def __str__(self):
        path = self._csv_file
        return f'Expected the supplied antibody file {path} to be a .csv file'
class InvalidCSVLengthError(BaseInvalidAntibodyCSVError):
    """
    Exception for an antibody CSV file that has no data rows after its header.
    """

    def __str__(self):
        template = ('Expected the supplied antibody file {} to have at least '
                    'one non-header row.')
        return template.format(self._csv_file)
class InvalidRowError(Exception):
    """
    Exception to raise when a particular row is invalid.

    :param data_violations: Every problem found in the row; the value of each
        violation is listed in the error message.
    :param row_num: The number of the offending row, as supplied by the caller.
    """

    def __init__(self, data_violations: List[DataViolation], row_num: int):
        # Keep the structured details so callers can inspect them instead of
        # parsing the message text.
        self.data_violations = list(data_violations)
        self.row_num = row_num
        descriptions = ', '.join(
            violation.value for violation in self.data_violations)
        super().__init__(f'Row {row_num} of the supplied file contains invalid '
                         f'information: {descriptions}')

    def __reduce__(self):
        # Exception's default pickling calls `cls(*self.args)`, i.e. with only
        # the formatted message, which does not fit this constructor.
        return (type(self), (self.data_violations, self.row_num))
# ==============================================================================
# Private Helpers
# ==============================================================================
def _csv_headers_are_valid(csv_headers: Optional[List[str]],
                           valid_headers: List[Headers]) -> bool:
    """
    Return whether the supplied csv headers match the order and spelling of the
    valid headers. Note that the order is predefined by design in order to
    minimize processing on our end.

    Each CSV header is stripped of surrounding whitespace and upper-cased
    before being compared with the corresponding header enum's value.

    :param csv_headers: The headers of a CSV file, or None when the file has no
        header row (which is what `csv.DictReader.fieldnames` reports for an
        empty file).
    :param valid_headers: Header enums to compare with the CSV headers.
    :return: True only if the headers match one-to-one and in order.
    """
    if csv_headers is None:
        # Without a header row the file cannot match any expected layout.
        return False
    cleaned_headers = [header.strip().upper() for header in csv_headers]
    expected_headers = [header.value for header in valid_headers]
    return cleaned_headers == expected_headers
def _get_name_violation(name: str) -> Optional[DataViolation]:
    """
    Return `DataViolation.NAME` when `name` is rejected by
    `fileutils.is_valid_jobname`, otherwise None.
    """
    is_valid = fileutils.is_valid_jobname(name)
    return None if is_valid else DataViolation.NAME
def _get_num_columns_violation(
        extra_vals: Optional[List[str]]) -> Optional[DataViolation]:
    """
    Report a number-of-columns violation when a row has cells beyond the
    expected columns.

    Rows with too few cells are deliberately not flagged here: those cells get
    filled with a predetermined value that is validated elsewhere.

    :param extra_vals: The overflow cells of a row, or None if there are none.
    :return: `DataViolation.NUM_COLS` if any overflow cells exist, else None.
    """
    return DataViolation.NUM_COLS if extra_vals else None
def _get_heavy_chain_violation(heavy_seq: str) -> Optional[DataViolation]:
    """
    Return `DataViolation.HC_SEQ` if `heavy_seq` is not a valid heavy chain
    sequence, otherwise None.

    :param heavy_seq: The heavy chain sequence cell of a CSV row.
    :raises ImportError: If the optional Prime antibody package could not be
        imported at module load time; sequence validation depends on it.
    """
    if antibody is None:
        # The module-level import falls back to None; without this check the
        # call below would fail with an opaque AttributeError on None.
        raise ImportError('Heavy chain sequence validation requires '
                          'schrodinger.application.prime.packages.antibody, '
                          'which could not be imported.')
    if not antibody.heavy_chain_seq_is_valid(heavy_seq):
        return DataViolation.HC_SEQ
    return None
def _get_light_chain_violation(light_seq: str) -> Optional[DataViolation]:
    """
    Return `DataViolation.LC_SEQ` if `light_seq` is not a valid light chain
    sequence, otherwise None.

    :param light_seq: The light chain sequence cell of a CSV row.
    :raises ImportError: If the optional Prime antibody package could not be
        imported at module load time; sequence validation depends on it.
    """
    if antibody is None:
        # The module-level import falls back to None; without this check the
        # call below would fail with an opaque AttributeError on None.
        raise ImportError('Light chain sequence validation requires '
                          'schrodinger.application.prime.packages.antibody, '
                          'which could not be imported.')
    if not antibody.light_chain_seq_is_valid(light_seq):
        return DataViolation.LC_SEQ
    return None
# ==============================================================================
# Validators
# ==============================================================================
@dataclass(frozen=True)
class _AbstractAntibodyCSVValidator:
    """
    A class with the ability to validate row data from a CSV file with antibody
    sequences. Instance variables are immutable.

    Subclasses should never be instantiated directly. Instead, use
    `get_reader()` to get a reader instance that is equipped with the correct
    validator.

    Subclasses must define only `modeling_mode` and `validators_by_header`.
    The insertion order of `validators_by_header`, minus `Headers.EXTRA_COLS`,
    becomes `valid_headers`, i.e. the exact column order a CSV file's header
    row must have for this validator to be selected.
    """
    modeling_mode: ClassVar[ModelingMode] = NotImplemented
    validators_by_header: ClassVar[Dict[Headers, Callable]] = NotImplemented
    # Derived per subclass in `__init_subclass__`; None on this base class.
    valid_headers: ClassVar[Optional[Tuple[Headers, ...]]] = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if cls.validators_by_header is NotImplemented:
            # Fail with a clear message instead of an AttributeError on
            # `NotImplemented.keys()`.
            raise TypeError(f'{cls.__name__} must define validators_by_header')
        # EXTRA_COLS is not a real column; it only collects overflow cells.
        cls.valid_headers = tuple(
            header for header in cls.validators_by_header
            if header is not Headers.EXTRA_COLS)

    def determine_data_violations(
            self,
            row_data: Dict[Headers, str],
            validate_seqs: bool = False) -> List[DataViolation]:
        """
        Return all data violations associated with the given row data.

        :param row_data: The data contained in a single CSV row.
        :param validate_seqs: Whether to also check the light and heavy chain seqs.
        :return: The violations found, in `validators_by_header` order.
        """
        violations = []
        for header, validator in self.validators_by_header.items():
            if not validate_seqs and header in SEQ_HEADERS:
                continue
            # Missing keys (e.g. EXTRA_COLS when there are no overflow cells)
            # are passed to the validator as None.
            data = row_data.get(header)
            if violation := validator(data):
                violations.append(violation)
        return violations

    def validate_row(self,
                     row_data: Dict[Headers, str],
                     row_num: int,
                     validate_seqs: bool = False):
        """
        Raise an error if any information in the supplied row is invalid.

        :param row_data: The data contained in a single CSV row.
        :param row_num: The row number to report in the error message.
        :param validate_seqs: Whether to also check the light and heavy chain seqs.
        :raises InvalidRowError: If any data violations are found.
        """
        if violations := self.determine_data_violations(row_data,
                                                        validate_seqs):
            raise InvalidRowError(violations, row_num)
class _SingleDomainAntibodyCSVValidator(_AbstractAntibodyCSVValidator):
    """
    Validator for single-domain antibody CSV files.

    The insertion order of `validators_by_header` (excluding
    `Headers.EXTRA_COLS`, which only catches overflow cells) defines the
    required column order: `Headers.NAME`, then `Headers.HC`.
    """
    modeling_mode = ModelingMode.SINGLE_DOMAIN
    validators_by_header = {
        Headers.EXTRA_COLS: _get_num_columns_violation,
        Headers.NAME: _get_name_violation,
        Headers.HC: _get_heavy_chain_violation,
    }
class _MonospecificAntibodyCSVValidator(_AbstractAntibodyCSVValidator):
    """
    Validator for monospecific antibody CSV files.

    The insertion order of `validators_by_header` (excluding
    `Headers.EXTRA_COLS`, which only catches overflow cells) defines the
    required column order: `Headers.NAME`, `Headers.HC`, then `Headers.LC`.
    """
    modeling_mode = ModelingMode.MONOSPECIFIC
    validators_by_header = {
        Headers.EXTRA_COLS: _get_num_columns_violation,
        Headers.NAME: _get_name_violation,
        Headers.HC: _get_heavy_chain_violation,
        Headers.LC: _get_light_chain_violation
    }
class _BispecificAntibodyCSVValidator(_AbstractAntibodyCSVValidator):
    """
    Validator for bispecific antibody CSV files (two heavy/light chain pairs).

    The insertion order of `validators_by_header` (excluding
    `Headers.EXTRA_COLS`, which only catches overflow cells) defines the
    required column order: `Headers.NAME`, `Headers.HC1`, `Headers.LC1`,
    `Headers.HC2`, then `Headers.LC2`.

    Not currently listed in `_VALIDATORS` (BIOLUM-4664), so `get_reader()`
    never selects it.
    """
    modeling_mode = ModelingMode.BISPECIFIC
    validators_by_header = {
        Headers.EXTRA_COLS: _get_num_columns_violation,
        Headers.NAME: _get_name_violation,
        Headers.HC1: _get_heavy_chain_violation,
        Headers.LC1: _get_light_chain_violation,
        Headers.HC2: _get_heavy_chain_violation,
        Headers.LC2: _get_light_chain_violation
    }
# Validator instances that `get_reader()` tries in order; the first one whose
# `valid_headers` match the CSV header row is used.
# TODO: add bispecifics back in BIOLUM-4664
_VALIDATORS = tuple(
    validator_cls() for validator_cls in (
        _SingleDomainAntibodyCSVValidator,
        _MonospecificAntibodyCSVValidator,
    ))
class _AntibodyCSVReader(csv.DictReader):
    """
    A CSV DictReader that uses an antibody CSV validator to potentially
    validate the data in each row of the file.

    Rows are keyed by the validator's `valid_headers`. Any cells beyond those
    columns are collected in a list under `Headers.EXTRA_COLS`, and missing
    cells are filled with ''. Because the field names are supplied explicitly,
    this reader does not consume a header row itself.

    The validator should be supplied by `get_reader()`.
    """

    def __init__(self, validator, *args, **kwargs):
        self._validator: _AbstractAntibodyCSVValidator = validator
        super().__init__(*args,
                         fieldnames=self._validator.valid_headers,
                         restkey=Headers.EXTRA_COLS,
                         restval='',
                         **kwargs)

    def validate_row(self,
                     row_data: Dict[Headers, str],
                     row_num: int,
                     validate_seqs: bool = False):
        """
        Raise `InvalidRowError` if any information in the supplied row is
        invalid.

        :param row_data: A row produced by this reader.
        :param row_num: The row number to report in the error message.
        :param validate_seqs: Whether to also check the chain sequences.
        """
        self._validator.validate_row(row_data, row_num, validate_seqs)

    def determine_data_violations(
            self,
            row_data: Dict[Headers, str],
            validate_seqs: bool = False) -> List[DataViolation]:
        """
        Return all data violations in the supplied row.

        :param row_data: A row produced by this reader.
        :param validate_seqs: Whether to also check the chain sequences. The
            default of False matches this method's previous behavior.
        """
        return self._validator.determine_data_violations(
            row_data, validate_seqs)
# ==============================================================================
# Public API
# ==============================================================================
@contextmanager
def get_reader(csv_file: str) -> Generator[_AntibodyCSVReader, None, None]:
    """
    A context manager that returns an antibody CSV reader equipped with the
    appropriate validator class given the supplied CSV file. Raises an error if
    the given file does not meet the standards for proper parsing.

    The header row is consumed here to choose a validator, so the yielded
    reader starts at the first data row. The file is closed when the context
    exits.

    :param csv_file: A CSV file with antibody sequences.
    :raises InvalidFileTypeError: If `csv_file` is not a CSV file.
    :raises InvalidCSVFormattingError: If the file has no header row or its
        headers do not match any supported layout.
    """
    if not fileutils.is_csv_file(csv_file):
        raise InvalidFileTypeError(csv_file)
    with csv_unicode.reader_open(csv_file) as ab_csv_handle:
        # Reading `fieldnames` consumes the header line from the shared handle.
        # It is None for an empty file, which must not match any layout.
        csv_headers = csv.DictReader(ab_csv_handle).fieldnames or []
        for validator in _VALIDATORS:
            if _csv_headers_are_valid(csv_headers, validator.valid_headers):
                yield _AntibodyCSVReader(validator, f=ab_csv_handle)
                return
    # csv columns do not match any of the parser classes
    raise InvalidCSVFormattingError(csv_file)