"""
This module provides an API for performing GPU-based fingerprint similarity
queries against compound databases on FPsim servers. Two variants of the API are
provided: a query function and the FPsimSearcher class.
The module-level query function can be used to query either all or a specific
subset of the public databases available on a single FPsim server.
The FPsimSearcher class allows multiple servers to be specified, each with
either all public databases or a specific subset. The entire collection of
databases may then be queried in a single call.
The server url should be the base url without the endpoint extension (e.g.
http://fpsim.foo.com/, *not* http://fpsim.foo.com/similarity_search)
"""
import collections
import functools
import urllib
import urllib.parse
from typing import Union

import pandas
import requests

from schrodinger.infra import licensing
# Sentinel for an unspecified url/dbnames argument. Functions compare against
# it by identity (``is DEFAULT``) and then fall back to the licensed default
# server url or to all public databases of the server.
DEFAULT = 'FPSIM_DEFAULT'
# Default number of matches requested by query().
DEFAULT_MAX_RESULTS = 100
# Default minimum similarity for matches, as a fraction in [0, 1].
DEFAULT_SIMILARITY_CUTOFF = 0.5
# Separator between multiple corporate ids inside a single result's corp_ids
# field; _parse_corp_id_str splits on it.
CORP_ID_SPLITTER = ';:;'
# Value the server places in the first field of a lone result row to signal a
# failed search; query() raises SearchError when it sees it.
SEARCH_ERROR_STRING = 'SERVER_ERROR_ON_SEARCH'
#===============================================================================
# Functional API
#===============================================================================
def query(smiles,
          url=DEFAULT,
          max_results=DEFAULT_MAX_RESULTS,
          similarity_cutoff=DEFAULT_SIMILARITY_CUTOFF,
          dbnames=DEFAULT):
    """
    Perform an FPSim query for a given smiles on a specific server.

    :param smiles: the query smiles to search against
    :param url: the base FPsim server URL
    :param max_results: max number of matches to return
    :param similarity_cutoff: minimum similarity cutoff for matches (0 - 1.0).
        Values above 1 are interpreted as a percentage (e.g. 50 -> 0.5).
    :param dbnames: name(s) of databases to search; either a single name or an
        iterable of names. Default: all public databases
    :return: a pandas.DataFrame of similar smiles with corresponding corporate
        ids and similarity scores.
    :raises ServerError: if the server responds with an unsuccessful status
    :raises QueryError: if the server reports an error processing the query
    :raises SearchError: if the server reports that the search itself failed
    """
    if dbnames is DEFAULT:
        dbnames = get_public_databases(url)
    elif isinstance(dbnames, str):
        # A bare string would otherwise be joined character by character.
        dbnames = [dbnames]
    else:
        # Materialize once so names and keys are built from the same sequence,
        # even when a one-shot iterable (e.g. a generator) is passed.
        dbnames = list(dbnames)
    # dbkeys is sent as a comma-separated list parallel to dbnames, with one
    # empty key per database name.
    dbkeys = [''] * len(dbnames)
    similarity_cutoff = _convert_cutoff_to_fraction(similarity_cutoff)
    data = {
        'smiles': smiles,
        'return_count': str(max_results),
        'similarity_cutoff': str(similarity_cutoff),
        'dbnames': ','.join(dbnames),
        'dbkeys': ','.join(dbkeys)
    }
    response = _raw_query(url, data)
    if not response.ok:
        raise ServerError(f'Bad server response: {url}')
    json_results = response.json()
    if json_results == 'Server error':
        raise QueryError(f'Problem with query: {data}')
    # A lone row whose first field is SEARCH_ERROR_STRING signals a failed
    # search; guard against an empty row before indexing into it.
    if (len(json_results) == 1 and json_results[0] and
            json_results[0][0] == SEARCH_ERROR_STRING):
        raise SearchError(f'Problem with search: {smiles}')
    return _json_to_pandas(json_results)
@functools.lru_cache()
def get_public_databases(url=DEFAULT):
    """
    Retrieves a list of database names for the public databases on an FPsim
    server. Names ending in '.fsim' are returned without that suffix.

    Results are cached per url (functools.lru_cache), so the same list object
    is returned on repeated calls; callers should not mutate it.

    :param url: the base FPsim server url. By default this will be the
        Schrodinger FPsim server.
    :return: a list of database names, which may be used as the dbnames
        parameter to a query.
    :raises requests.exceptions.Timeout: when a non-default url is given and
        the request, which should be near-instantaneous (up to network delays),
        exceeds the 3 s connect/read timeout. No timeout is applied for the
        default server.
    :raises ServerError: if the server responds with an unsuccessful status.
    """
    if url is DEFAULT:
        # Also performs license check
        url = licensing.gpusim_current_url(
            licensing.SimilarityEndpoint.DATABASES)
        timeout = None
    else:
        # NOTE(review): the return value of licenseExists is discarded; confirm
        # whether it raises on a missing license or its result should be checked.
        licensing.licenseExists(licensing.GPUSIMILARITY)
        url = urllib.parse.urljoin(url, 'dbnames')
        timeout = 3
    response = requests.get(url, verify=True, timeout=timeout)
    if not response.ok:
        raise ServerError(f"Can't fetch db list from {url}")
    suffix = '.fsim'
    # str.rstrip('.fsim') would remove any trailing '.', 'f', 's', 'i', 'm'
    # characters (e.g. 'chems.fsim' -> 'che'), so slice off the exact suffix.
    return [
        name[:-len(suffix)] if name.endswith(suffix) else name
        for name in response.json()
    ]
#===============================================================================
# Searcher API
#===============================================================================
class FPsimSearcher:
    """
    A search context for FPsim queries. Specify a collection of data sources on
    the object to perform multiple queries on the same servers and databases.

    Sources are stored as a mapping of server url to a set of database names,
    so adding the same database for a server more than once has no effect.
    """

    def __init__(self, url=DEFAULT, dbnames=DEFAULT):
        """
        :param url: the base FPsim server url of an initial source. By default
            this will be the Schrodinger FPsim server. Pass None to create a
            searcher with no sources.
        :param dbnames: database name(s) for the initial source; see addSource.
        """
        self._server_dbs = collections.defaultdict(set)
        if url is not None:
            self.addSource(url, dbnames)

    def addSource(self, url=DEFAULT, dbnames=DEFAULT):
        """
        Add databases to the searcher specified by URL and database name(s).

        :param url: the base FPsim server url. By default this will be the
            Schrodinger FPsim server. A falsy url (None or '') is ignored.
        :param dbnames: the database name to add for the specified server. This
            may be a single name or an iterable of names. By default all public
            databases on the server are added.
        """
        # Check the url first so an ignored source never triggers a request
        # for its public database list.
        if not url:
            return
        if dbnames is DEFAULT:
            dbnames = get_public_databases(url)
        elif isinstance(dbnames, str):
            # set(name) would split the name into characters; wrap it instead.
            dbnames = {dbnames}
        self._server_dbs[url].update(dbnames)

    def getSources(self):
        """
        Gets all the sources that have been added to this searcher.

        :return: a dictionary of {url: dbnames}, with the DEFAULT placeholder
            resolved to the actual default url. The name sets are copies, so
            modifying them does not alter this searcher.
        """
        sources = {}
        for url, dbnames in self._server_dbs.items():
            if url is DEFAULT:
                url = _get_default_url()
            sources[url] = set(dbnames)
        return sources

    def query(self,
              smiles,
              max_results=DEFAULT_MAX_RESULTS,
              similarity_cutoff=DEFAULT_SIMILARITY_CUTOFF):
        """
        Perform a query on all the sources in this searcher. Each server will be
        queried in the order it was added via addSource until max_results
        results have been found. The results from each server are simply
        concatenated, with no de-duplication of results.
        See module function query() for parameter docs.

        :return: a pandas.DataFrame of all results with a fresh 0..n-1 index.
        """
        frames = []
        remaining = max_results
        for url, dbnames in self._server_dbs.items():
            # `query` here is the module-level function, not this method.
            df = query(smiles,
                       url,
                       max_results=remaining,
                       similarity_cutoff=similarity_cutoff,
                       dbnames=dbnames)
            frames.append(df)
            remaining -= len(df)
            if remaining < 1:
                break
        if not frames:
            return _json_to_pandas([])
        # ignore_index avoids duplicate row labels, since each server's frame
        # carries its own 0..k-1 index.
        return pandas.concat(frames, ignore_index=True)
#===============================================================================
# Exception classes
#===============================================================================
class ServerError(RuntimeError):
    """Raised when an FPsim server returns an unsuccessful HTTP response."""
class QueryError(RuntimeError):
    """Raised when an FPsim server reports a server error for a query request."""
class SearchError(RuntimeError):
    """Raised when an FPsim server signals that the similarity search failed."""
#===============================================================================
# Utility functions
#===============================================================================
def _json_to_pandas(json_results):
    """
    Convert raw JSON search rows into a DataFrame.

    :param json_results: sequence of rows laid out as
        [corp_ids, smiles, similarity], where corp_ids is a single string of
        ids joined by CORP_ID_SPLITTER
    :return: pandas.DataFrame with columns corp_ids (each a list of ids),
        smiles and similarity
    """
    column_names = ['corp_ids', 'smiles', 'similarity']
    frame = pandas.DataFrame(json_results, columns=column_names)
    return frame.assign(corp_ids=frame['corp_ids'].apply(_parse_corp_id_str))
def _parse_corp_id_str(corp_id_str):
    """
    Split a corp_ids result field into its individual corporate ids.

    :param corp_id_str: corporate ids joined by CORP_ID_SPLITTER
    :return: list of corporate id strings
    """
    ids = corp_id_str.split(sep=CORP_ID_SPLITTER)
    return ids
def _get_default_url():
    """
    Return the similarity-search (FASTSIM endpoint) url of the default FPsim
    server, as reported by the licensing module.
    """
    fastsim_endpoint = licensing.SimilarityEndpoint.FASTSIM
    return licensing.gpusim_current_url(fastsim_endpoint)
def _raw_query(url, data):
    """
    POST a similarity search request to an FPsim server.

    :param url: base server url, or DEFAULT for the default server's endpoint
    :param data: form fields for the request, as built by query()
    :return: the requests.Response from the server
    """
    if url is not DEFAULT:
        endpoint = urllib.parse.urljoin(url, 'similarity_search_json')
    else:
        endpoint = _get_default_url()
    return requests.post(endpoint, data=data, verify=True)
def _convert_cutoff_to_fraction(cutoff: Union[int, float]) -> float:
    """
    Return a similarity cutoff between 0 and 1. Converts a percentage
    cutoff to a fraction. A fractional cutoff is returned untouched if
    already in the range [0.0, 1.0].
    This function is required in order to always pass a fractional similarity
    cutoff argument into the FPSim search while preserving compatibility with
    the previous convention of accepting a percentage.
    :param cutoff: minimum similarity cutoff for matches. Assumed to be a number
        greater than or equal to 0.
    :return: A similarity cutoff between 0 and 1.
    """
    # Account for floating point noise
    epsilon = 1.e-15
    if 0 <= cutoff <= 1.0 + epsilon:
        return min(cutoff, 1.0)
    return min(cutoff, 100) / 100