Source code for schrodinger.application.phase.packages.shape_diversity

"""
Provides functionality for selecting diverse structures from shape
screen hits.

Copyright Schrodinger LLC, All Rights Reserved.
"""

import sqlite3
import time

from schrodinger import structure
from schrodinger.application.combinatorial_diversity import diversity_selector
from schrodinger.application.combinatorial_diversity import diversity_splitter
from schrodinger.application.combinatorial_diversity import driver_utils
from schrodinger.infra import canvas
from schrodinger.infra import phase

# Ceiling on the fraction of hits that the diversity selection may keep:
MAX_DIVERSE_FRACTION: float = 0.2

# Number of fingerprints generated between consecutive progress messages:
FP_PROGRESS_INTERVAL: int = 10_000

# Name of the main data table inside each fingerprint file:
FP_DATA_TABLE: str = 'DATA'

# Desired number of diverse structures to pick from each chunk of hit space:
TARGET_DIVERSE_PER_CHUNK: int = 500

# Upper bound on diversity optimization cycles (passed as opt_cycles):
MAX_OPT_CYCLES: int = 50


def generate_shape_gpu_fingerprints(hits_file,
                                    fp_file_prefix,
                                    logger=None,
                                    progress_interval=FP_PROGRESS_INTERVAL):
    """
    Generates molprint2D fingerprints for shape_screen_gpu hits provided in
    Maestro or SD format. A separate fingerprint file <fp_file_prefix>_<i>.fp
    is generated from the hits produced by each shape query, where <i> runs
    from 1 to the number of shape queries. Query structures themselves are
    skipped, and the name stored for each row in a given fingerprint file
    corresponds to the 0-based position of the structure in the hits file.

    Structures lacking the phase.PHASE_SHAPE_SIM property are treated as
    shape queries; every other structure is treated as a hit of the most
    recently seen query. Hits whose fingerprint cannot be generated
    (canvas.ChmException) are skipped.

    :param hits_file: Name of shape_screen_gpu hits file
    :type hits_file: str

    :param fp_file_prefix: Prefix of output fingerprint files
    :type fp_file_prefix: str

    :param logger: Logger for info level progress messages
    :type logger: logging.Logger or NoneType

    :param progress_interval: Interval between progress messages
    :type progress_interval: int

    :return: List of fp_name, fp_count tuples, one per shape query
    :rtype: list((str, int))

    :raise: ValueError if a hit is encountered before any shape query
    """
    adaptor = canvas.ChmMmctAdaptor()
    stereo = canvas.ChmMmctAdaptor.NoStereo  # Not needed for molprint2D.
    # The diversity selection code expects a SMILES column, but its contents
    # are irrelevant for our purposes.
    extra_data = {'SMILES': 'dummy'}
    query_count = 0
    total_fp_count = 0
    query_fp_count = 0
    fp_tuples = []
    fp_file = ''
    fp_generator = None
    if logger:
        t1 = time.process_time()
    # hits_file is organized as follows:
    #   Shape query 1
    #   Hit 1 for shape query 1
    #   Hit 2 for shape query 1
    #   etc.
    #   Shape query 2
    #   Hit 1 for shape query 2
    #   Hit 2 for shape query 2
    #   etc.
    try:
        with structure.StructureReader(hits_file) as reader:
            for hit_pos, st in enumerate(reader):
                if phase.PHASE_SHAPE_SIM not in st.property:
                    # It's a shape query structure. Finish the previous
                    # query's fingerprint file, if any, and start a new one.
                    if fp_generator is not None:
                        # Clear the reference first so that a failing close()
                        # is not retried by the finally clause below.
                        finished, fp_generator = fp_generator, None
                        finished.close()
                        fp_tuples.append((fp_file, query_fp_count))
                    query_count += 1
                    fp_file = f'{fp_file_prefix}_{query_count}.fp'
                    if logger:
                        logger.info(f'Writing fingerprints to {fp_file}')
                    new_generator = canvas.ChmMolprint2D32()
                    new_generator.open(fp_file)
                    fp_generator = new_generator
                    query_fp_count = 0
                    continue
                # It's a hit.
                if fp_generator is None:
                    raise ValueError(
                        f'{hits_file}: hit at position {hit_pos} precedes '
                        'the first shape query')
                try:
                    fp_generator.write(adaptor.create(st, stereo),
                                       str(hit_pos), extra_data)
                except canvas.ChmException as err:
                    if logger:
                        logger.info(f'{str(err)} - skipping structure')
                    continue
                total_fp_count += 1
                query_fp_count += 1
                if logger and total_fp_count % progress_interval == 0:
                    msg = f'{total_fp_count:,} fingerprints generated'
                    logger.info(msg)
    finally:
        # Release the current fingerprint file even if an error occurred.
        if fp_generator is not None:
            fp_generator.close()
    if query_count > 0:
        # Record the last query's file, which was closed above.
        fp_tuples.append((fp_file, query_fp_count))
    if logger:
        logger.info(f'Number of fingerprint files created: {len(fp_tuples)}')
        msg = f'Total number of fingerprints generated: {total_fp_count:,}'
        logger.info(msg)
        t2 = time.process_time()
        logger.info('CPU time: %.2f sec' % (t2 - t1))
    return fp_tuples
def get_min_pop(diverse_fraction):
    """
    Chooses the minimum population of each chunk of hit space for the given
    diverse fraction. The result corresponds to the combinatorial_diversity
    -min_pop parameter.

    For diverse fractions of 0.025 or smaller, driver_utils.DEFAULT_MIN_POP
    is returned unchanged. For larger fractions, the default is repeatedly
    halved (with truncation) until 2 * round(diverse_fraction * min_pop)
    no longer exceeds TARGET_DIVERSE_PER_CHUNK.

    :param diverse_fraction: Diverse fraction of hits to select
    :type diverse_fraction: float

    :return: Minimum number of hits per chunk
    :rtype: int

    :raise: ValueError if diverse_fraction is not within
        (0.0, MAX_DIVERSE_FRACTION]
    """
    if diverse_fraction <= 0.0 or diverse_fraction > MAX_DIVERSE_FRACTION:
        raise ValueError('diverse_fraction must lie within the interval '
                         f'(0.0, {MAX_DIVERSE_FRACTION}]')
    pop = driver_utils.DEFAULT_MIN_POP
    if diverse_fraction <= 0.025:
        return pop
    # Always halve at least once, then keep halving while a chunk of this
    # population would still yield too many diverse structures.
    pop = int(pop / 2)
    while 2 * int(round(diverse_fraction * pop)) > TARGET_DIVERSE_PER_CHUNK:
        pop = int(pop / 2)
    return pop
def get_num_probes(num_hits, min_pop):
    """
    Determines an appropriate number of hit space probes for the specified
    number of hits and minimum population per hit space chunk. The number
    of probes will be driver_utils.DEFAULT_NUM_PROBES unless additional
    probes are needed to ensure that min_pop * 2**(num_probes - 1) is at
    least num_hits. The returned value corresponds to the
    combinatorial_diversity -ndim parameter.

    :param num_hits: Total number of hits
    :type num_hits: int

    :param min_pop: Minimum number of hits per chunk
    :type min_pop: int

    :return: Number of probes
    :rtype: int

    :raise: ValueError if min_pop is not positive and additional probes
        would be needed (the capacity could then never reach num_hits)
    """
    num_probes = driver_utils.DEFAULT_NUM_PROBES
    hit_capacity = min_pop * 2**(num_probes - 1)
    if hit_capacity >= num_hits:
        return num_probes
    if min_pop <= 0:
        # Doubling a non-positive capacity never reaches num_hits; the
        # original loop would spin forever here.
        raise ValueError(f'min_pop must be positive; got {min_pop}')
    # Each additional probe doubles the number of hits that can be covered.
    while hit_capacity < num_hits:
        num_probes += 1
        hit_capacity *= 2
    return num_probes
def get_shape_gpu_hits_positions(fp_file, fp_positions):
    """
    Given a fingerprint file created by generate_shape_gpu_fingerprints and
    a list of 0-based positions in that file, this function returns the
    corresponding 0-based positions in the hits file from which the
    fingerprint file was created. Accounts for the presence of shape queries
    in the hits file, grouping of hits by shape query and any hits for which
    fingerprint generation failed.

    The hits file position is read from the name column of each fingerprint
    row, and fingerprint row ids are 1-based, so position p is looked up as
    id p + 1.

    :param fp_file: Name of fingerprint file
    :type fp_file: str

    :param fp_positions: 0-based positions in the fingerprint file.
        Duplicate positions are looked up only once.
    :type fp_positions: Any iterable of int values

    :return: 0-based positions into the hits file
    :rtype: list(int)
    """
    # Older SQLite builds allow at most 999 bound parameters per statement,
    # so the ids are looked up in batches no larger than that.
    max_params = 999
    # int() also normalizes numpy-style integers, which sqlite3 cannot bind.
    row_ids = sorted({int(pos) + 1 for pos in fp_positions})
    hits_positions = []
    conn = sqlite3.connect(fp_file)
    try:
        cursor = conn.cursor()
        for start in range(0, len(row_ids), max_params):
            batch = row_ids[start:start + max_params]
            placeholders = ','.join('?' * len(batch))
            select = (f'SELECT name FROM {FP_DATA_TABLE} '
                      f'WHERE id IN ({placeholders})')
            hits_positions.extend(
                int(row[0]) for row in cursor.execute(select, batch))
    finally:
        # Close the connection even if the query fails.
        conn.close()
    return hits_positions
def select_shape_gpu_hits(hits_file_in,
                          diverse_fraction,
                          hits_file_out,
                          fp_file_prefix,
                          logger=None,
                          progress_interval=FP_PROGRESS_INTERVAL):
    """
    Selects a specified fraction of structurally diverse hits from a
    shape_screen_gpu hits file and writes a new hits file containing the
    shape queries and only the diverse hits for each shape query.

    :param hits_file_in: Name of input shape_screen_gpu hits file in
        Maestro or SD format
    :type hits_file_in: str

    :param diverse_fraction: Diverse fraction of hits to select
    :type diverse_fraction: float

    :param hits_file_out: Name of output hits file in Maestro or SD format
    :type hits_file_out: str

    :param fp_file_prefix: Prefix of temporary fingerprint files that will
        be generated from the hits for each shape query
    :type fp_file_prefix: str

    :param logger: Logger for info level progress messages
    :type logger: logging.Logger or NoneType

    :param progress_interval: Interval between progress messages for
        fingerprint generation
    :type progress_interval: int

    :return: Number of diverse hits requested from the diversity selector
        for each shape query, in query order
    :rtype: list[int]

    :raise: ValueError if diverse_fraction is outside the legal range
    """
    # Fail fast: get_min_pop raises ValueError for an illegal
    # diverse_fraction, so check it before the potentially lengthy
    # fingerprint generation step rather than after it.
    get_min_pop(diverse_fraction)
    if logger:
        logger.info(f'\nGenerating molprint2D fingerprints from {hits_file_in}')
    fp_tuples = generate_shape_gpu_fingerprints(hits_file_in, fp_file_prefix,
                                                logger, progress_interval)
    nqueries = len(fp_tuples)
    hit_positions = set()
    num_hits_per_query = []
    for i, (fp_file, fp_count) in enumerate(fp_tuples, 1):
        ndiverse = int(round(diverse_fraction * fp_count))
        num_hits_per_query.append(ndiverse)
        if logger:
            msg = (f'\nQuery {i} of {nqueries}: Selecting {ndiverse:,} '
                   f'diverse structures from a total of {fp_count:,} hits')
            logger.info(msg)
        if ndiverse == 0:
            # Nothing to select (e.g. a query with no hits), so don't run
            # the splitter and selector on this fingerprint file.
            continue
        diverse_rows = _select_diverse_rows(fp_file, ndiverse,
                                            diverse_fraction, logger)
        hit_positions.update(
            get_shape_gpu_hits_positions(fp_file, diverse_rows))
    _write_diverse_hits(hits_file_in, hits_file_out, hit_positions, nqueries,
                        logger)
    return num_hits_per_query


def _select_diverse_rows(fp_file, ndiverse, diverse_fraction, logger=None):
    """
    Splits the hits in a fingerprint file into chunks with split_hits and
    selects diverse rows from each chunk, distributing ndiverse across the
    chunks with phase.partitionValues.

    :param fp_file: Name of fingerprint file of hits for one shape query
    :type fp_file: str

    :param ndiverse: Total number of diverse rows to select
    :type ndiverse: int

    :param diverse_fraction: Diverse fraction of hits to select
    :type diverse_fraction: float

    :param logger: Logger for info level progress messages
    :type logger: logging.Logger or NoneType

    :return: 0-based fingerprint row numbers of the selected hits
    :rtype: list(int)
    """
    fp_domains = split_hits(fp_file, diverse_fraction)
    nchunks = len(fp_domains)
    ndiverse_per_chunk = phase.partitionValues(ndiverse, nchunks)
    diverse_rows = []
    for j, fp_domain in enumerate(fp_domains):
        ndiv_chunk = ndiverse_per_chunk[j]
        if logger:
            logger.info(f'Processing chunk {j + 1} of {nchunks}')
            msg = (f'Selecting {ndiv_chunk:,} diverse structures from '
                   f'{len(fp_domain):,} hits')
            logger.info(msg)
            t1 = time.process_time()
        selector = diversity_selector.DiversitySelector(
            fp_file, opt_cycles=MAX_OPT_CYCLES, fp_domain=fp_domain)
        selector.select(ndiv_chunk)
        if logger:
            t2 = time.process_time()
            msg = 'Average nearest-neighbor similarity: %.6f'
            logger.info(msg % selector._nn_sim_avg)
            logger.info('CPU time to select: %.2f sec' % (t2 - t1))
        diverse_rows.extend(selector.subset_rows)
    return diverse_rows


def _write_diverse_hits(hits_file_in,
                        hits_file_out,
                        hit_positions,
                        nqueries,
                        logger=None):
    """
    Copies every shape query, plus each hit whose 0-based position in
    hits_file_in appears in hit_positions, to hits_file_out.

    :param hits_file_in: Name of input shape_screen_gpu hits file
    :type hits_file_in: str

    :param hits_file_out: Name of output hits file
    :type hits_file_out: str

    :param hit_positions: 0-based positions of the hits to keep
    :type hit_positions: set(int)

    :param nqueries: Number of shape queries (used only for the message)
    :type nqueries: int

    :param logger: Logger for info level progress messages
    :type logger: logging.Logger or NoneType
    """
    nhits = len(hit_positions)
    if logger:
        query = 'queries' if nqueries > 1 else 'query'
        msg = (f'\nWriting shape {query} and {nhits} diverse hits to '
               f'{hits_file_out}')
        logger.info(msg)
        t1 = time.process_time()
    with structure.StructureReader(hits_file_in) as reader, \
            structure.StructureWriter(hits_file_out) as writer:
        for i, st in enumerate(reader):
            # A structure without a shape similarity is a shape query. This
            # is the same test used when generating the fingerprints, so the
            # hit positions computed there line up with these indices.
            if phase.PHASE_SHAPE_SIM not in st.property or i in hit_positions:
                writer.append(st)
    if logger:
        t2 = time.process_time()
        logger.info('CPU time to write: %.2f sec' % (t2 - t1))
def split_hits(fp_file, diverse_fraction):
    """
    Figuratively partitions the hits of a fingerprint file into roughly
    equal-sized chunks that occupy non-overlapping regions of fingerprint
    space. Nothing is written; each chunk is returned as a list of 0-based
    fingerprint row numbers.

    :param fp_file: Name of fingerprint file of hits
    :type fp_file: str

    :param diverse_fraction: Diverse fraction of hits to select
    :type diverse_fraction: float

    :return: 0-based fingerprint row numbers of each chunk
    :rtype: list(list(int))

    :raise: ValueError if diverse_fraction is outside the legal range
    """
    # get_min_pop also validates diverse_fraction before any file is opened.
    pop_floor = get_min_pop(diverse_fraction)
    hit_total = canvas.ChmFPIn32(fp_file).getRowCount()
    diverse_total = int(round(diverse_fraction * hit_total))
    probe_count = get_num_probes(hit_total, pop_floor)

    # Make each chunk populous enough that at least the minimum number of
    # diverse structures per chunk can be drawn from it.
    pop_floor = driver_utils.adjust_min_pop(pop_floor, diverse_total,
                                            driver_utils.MIN_DIVERSE_PER_CHUNK,
                                            hit_total)

    return diversity_splitter.DiversitySplitter(
        fp_file, pop_floor, probe_count).getOrthantRows()