Raw File
cached_values.py
import numpy as np
import pandas as pd


class CachedValues:
  """Class for mapping keys to pre-computed values.

  Keys are hashable objects in numpy arrays, for example, SMILES strings.

  Attributes:
    keys: The distinct keys, exactly as passed to the constructor.
    values: The values; `values[i]` is the value stored for `keys[i]`.
    cache: pandas Series indexed by key whose data is the position of that key
      in `keys` / `values`.
  """

  def __init__(self, keys: np.ndarray, values: np.ndarray):
    """Builds the key -> position lookup table.

    Args:
      keys: Array of distinct, hashable keys.
      values: Array of values with the same length as `keys`.

    Raises:
      ValueError: If `keys` contains duplicates, or if `keys` and `values`
        differ in length.
    """
    if len(set(keys)) != len(keys):
      raise ValueError('All keys should be distinct; duplicates found!')
    if len(keys) != len(values):
      raise ValueError('Keys and values should have the same length,'
                       ' found len(keys) of {} and len(values) of {}'.format(
                           len(keys), len(values)))
    self.keys = keys
    self.values = values
    self.cache = pd.Series(data=np.arange(len(keys)), index=keys)

  @classmethod
  def load_from_npz(cls,
                    filepath,
                    key_name='smiles',
                    value_name='prediction'):
    """Creates an instance from an .npz archive.

    Args:
      filepath: Path (or file object) accepted by `np.load`.
      key_name: Name of the array in the archive that holds the keys.
      value_name: Name of the array in the archive that holds the values.

    Returns:
      A new instance of `cls` built from the two stored arrays.
    """
    # SECURITY: allow_pickle=True is needed to read object-dtype key arrays,
    # but unpickling can execute arbitrary code. Only load trusted files.
    # The context manager closes the underlying file handle, which a bare
    # np.load() call would otherwise leave open.
    with np.load(filepath, allow_pickle=True) as data:
      return cls(data[key_name], data[value_name])

  def __call__(self, key_array: np.ndarray):
    """Looks up the cached values for an array of keys.

    Args:
      key_array: Numpy array of keys to look up. Keys may repeat.

    Returns:
      The entries of `values` that correspond to `key_array`, in the same
      order as `key_array`.

    Raises:
      ValueError: If `key_array` is not a numpy array, or if any of its keys
        is not present in the cache.
    """
    if not isinstance(key_array, np.ndarray):
      raise ValueError(f'key_array={key_array} should be a numpy array!')
    # get_indexer() reports unknown keys as -1. Indexing the Series directly
    # with a missing label raises KeyError on current pandas (and yielded
    # float NaN positions on old versions), so the missing-key ValueError
    # below could never be raised reliably.
    positions = self.cache.index.get_indexer(key_array)
    missing = positions < 0
    if np.any(missing):
      raise ValueError(
          f'key_array={key_array[missing]} not in cache index!')
    # Map through the Series data so `cache` stays the single source of truth
    # for where each key's value lives.
    indices = self.cache.to_numpy()[positions]
    return self.values[indices]

  def save_to_npz(self, filepath, key_name='smiles', value_name='prediction'):
    """Writes the keys and values to a compressed .npz archive.

    Args:
      filepath: Destination path (or file object) accepted by
        `np.savez_compressed`.
      key_name: Name under which the keys array is stored.
      value_name: Name under which the values array is stored.
    """
    np.savez_compressed(
        filepath, **{key_name: self.keys, value_name: self.values})
back to top