"""
See compare_csv docstring
$Revision 0.2 $
@copyright: (c) Schrodinger, LLC. All rights reserved
"""
import numpy as np
import pandas
from schrodinger.utils import log
# Fallback used by compare_csv when the caller supplies neither ``tolerance``
# nor ``reltol``; it is applied as an absolute tolerance.
DEFAULT_TOLERANCE = 0.005
"""Default tolerance if none is provided."""
def help():
    """
    Return a one-line usage string for the compare_csv workup.

    The optional arguments are listed in the same order as the parameters of
    compare_csv, so that the string is also correct when the arguments are
    passed positionally.

    :return: Usage message.
    :rtype: str
    """
    return ("Usage: compare_csv(test.csv, reference.csv, [tolerance], "
            "[reltol], [lines], [delimiter], [ignore_cols], [sort_by])")
def compare_csv(test_file,
                ref_file,
                tolerance=None,
                reltol=None,
                lines=None,
                delimiter=None,
                ignore_cols=None,
                sort_by=None):
    """
    Workup for comparing two CSV files.

    Example use::

        outcome_workup = compare_csv('test.csv', 'ref.csv', tolerance=0.05, lines=3, delimiter=' ')

    Numeric values will be compared using tolerance or reltol, as described
    below. Equality is required for values that cannot be cast as floats
    (strings for example).

    The lines in each CSV file are expected to line up (i.e. line 1 in test.csv
    is compared with line 1 from ref.csv). This means if a line is skipped in
    test.csv all subsequent lines will cause failures (so, many failure messages
    will be printed - one for each line after the skip).

    When differences are found they are written to the log file
    ``workup_compare_csv.log`` before the failure is raised.

    :param str test_file: Filename of csv to be tested.
    :param str ref_file: Filename of reference csv.
    :param float|None tolerance: Maximum possible deviation from the ref csv
        for numeric values. If this and reltol are both None, a default
        value will be used.
    :param float|None reltol: Maximum possible deviation from the ref csv,
        expressed as a relative value. For example if this is 0.02,
        values may be different by up to 2%.
    :param int lines: Number of lines to compare. Default (None or 0) is to
        compare all lines and require the same number of lines to be in the
        reference and test files. When a limit is given, both files must
        still provide the same number of rows within that limit.
    :param str delimiter: Delimiter to use while reading the csv. The default
        delimiter is ','.
    :param list ignore_cols: List of column names to ignore
    :param sort_by: Before comparing, sort the ref_file and test_file based on
        the column name(s) specified in sort_by.
    :type sort_by: `list` of `str`
    :return: True if the files match.
    :rtype: bool
    :raises AssertionError: If both tolerance and reltol are given, or if the
        files differ.
    :raises TypeError: If the tolerance that would be used is not numeric.
    """
    # Compare against None rather than relying on truthiness, so that an
    # explicit 0 still counts as "provided".
    if tolerance is not None and reltol is not None:
        msg = 'Only one of tolerance and reltol can be used.'
        raise AssertionError(msg)
    if tolerance is None and reltol is None:
        tolerance = DEFAULT_TOLERANCE

    # Validate the value that will actually be used. np.isreal() is not
    # suitable here: it reports True for None and other object inputs.
    active_tolerance = tolerance if tolerance is not None else reltol
    numeric_types = (int, float, np.integer, np.floating)
    if not isinstance(active_tolerance, numeric_types):
        msg = "One of tolerance and reltol must be defined and numeric (found \"tolerance {}, reltol {}\")".format(
            tolerance, reltol)
        raise TypeError(msg)

    if tolerance is not None:
        tolerance_args = {'atol': tolerance}
    else:
        tolerance_args = {'rtol': reltol}

    ref_df = pandas.read_csv(ref_file, delimiter=delimiter)
    test_df = pandas.read_csv(test_file, delimiter=delimiter)
    if ignore_cols:
        # Remove columns without raising if a column doesn't exist
        ref_df = ref_df.drop(columns=ignore_cols, errors='ignore')
        test_df = test_df.drop(columns=ignore_cols, errors='ignore')

    ref_rows, ref_cols = ref_df.shape
    test_rows, test_cols = test_df.shape
    # Number of rows each file contributes to the comparison. Without a limit
    # that is every row; with one, at most ``lines`` rows. Comparing these
    # catches a file that is truncated inside the compared range.
    ref_compared = min(lines, ref_rows) if lines else ref_rows
    test_compared = min(lines, test_rows) if lines else test_rows

    violations = []
    if ref_cols != test_cols:
        msg = 'Number of columns do not match. Reference: {}, Test: {}'
        violations.append(msg.format(ref_cols, test_cols))
    elif ref_compared != test_compared:
        msg = 'Number of rows do not match. Reference: {}, Test: {}'
        violations.append(msg.format(ref_rows, test_rows))
    else:
        if sort_by:
            ref_df = ref_df.sort_values(sort_by)
            test_df = test_df.sort_values(sort_by)
        violations.extend(compare_dfs(ref_df, test_df, tolerance_args, lines))

    if violations:
        log_name = 'workup_compare_csv.log'
        log.logging_config(file=log_name, format='%(message)s', filemode='w')
        logger = log.get_logger(log_name)
        logger.warning(f'Errors comparing {test_file} and {ref_file}')
        for violation in violations:
            logger.warning(violation)
        raise AssertionError("FAILURE: File %s was different from %s\n"
                             "Details can be found in %s" %
                             (test_file, ref_file, log_name))
    return True
def compare_dfs(ref_df, test_df, tolerance_args, lines):
    """
    Compare test and reference dfs using specified tolerance. Return a list
    of any violations.

    Rows and columns are paired by position. Values that are not real numbers
    (strings for example) must be equal; a pair of NaNs counts as a match;
    all other numeric values are compared with `numpy.isclose`, with the
    reference value as the baseline for any relative tolerance.

    :param pandas.DataFrame ref_df: Reference data.
    :param pandas.DataFrame test_df: Data to be tested.
    :param dict tolerance_args: Keyword arguments forwarded to
        `numpy.isclose` (for example ``{'atol': 0.005}`` or
        ``{'rtol': 0.02}``).
    :param int|None lines: Maximum number of rows to compare, counted from
        the first row. None or 0 compares all rows.
    :return: One message per mismatching value; empty if everything matches.
    :rtype: list[str]
    """
    violations = []
    row_pairs = zip(ref_df.iterrows(), test_df.iterrows())
    for position, ((ref_index, ref_row), (_, test_row)) in enumerate(row_pairs):
        # Apply the limit to the row position, not the index label: labels
        # are no longer sequential once a frame has been sorted.
        if lines and position >= lines:
            break
        for ref_val, test_val in zip(ref_row.tolist(), test_row.tolist()):
            msg = None
            if not np.isreal(ref_val) or not np.isreal(test_val):
                if ref_val != test_val:
                    msg = 'Error in row {}: value {} does not match reference value {}.'
            elif np.isnan(ref_val) and np.isnan(test_val):
                # Special case since nans are not equal to each other.
                pass
            # numpy.isclose(a, b) scales rtol by |b|, so the reference value
            # has to be the second argument.
            elif not np.isclose(test_val, ref_val, **tolerance_args):
                msg = 'Error in row {}: value {} does not match reference value {} within tolerance'
            if msg:
                # The message names the test value first, then the reference.
                violations.append(msg.format(ref_index, test_val, ref_val))
    return violations
if __name__ == "__main__":
    import sys

    # Check the argument count explicitly: an ``assert`` is removed when
    # Python runs with -O and gives no hint about the expected arguments.
    if len(sys.argv) != 3:
        sys.exit(f'Usage: {sys.argv[0]} test.csv reference.csv')
    # compare_csv raises on any difference, so reaching the print means the
    # two files matched.
    ret = compare_csv(*sys.argv[1:])
    if ret:
        print('Workflow passed')