#
#  Copyright (c) 2003-2021 Rational Discovery LLC
#
#  @@ All Rights Reserved @@
#  The contents are covered by the terms of the BSD license
#  which is included in the file license.txt, found at the root
#  of the RDKit source tree.
#
""" utility functionality for fingerprinting sets of molecules
 includes a command line app for working with fingerprints
 and databases


Sample Usage:

  python FingerprintMols.py  -d data.gdb \
        -t 'raw_dop_data' --smilesName="Structure" --idName="Mol_ID"  \
        --outTable="daylight_sig"

"""

import getopt
import sys,logging,tqdm

from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import MACCSkeys
from rdkit.ML.Cluster import Murtagh
import pickle


def GetRDKFingerprint(mol):
  """ fingerprints _mol_ using the default FingerprinterDetails settings

  Every attribute of a freshly constructed FingerprinterDetails
  (including its _fingerprinter_ choice) is forwarded to
  FingerprintMol as a keyword argument.
  """
  defaultSettings = vars(FingerprinterDetails())
  return FingerprintMol(mol, **defaultSettings)


def FoldFingerprintToTargetDensity(fp, **fpArgs):
  """ folds _fp_ in half repeatedly until it is dense enough

  Folding stops as soon as either:
    - the fraction of on bits is no longer below fpArgs['tgtDensity'], or
    - half of the current size would not exceed fpArgs['minSize']

  Returns the (possibly folded) fingerprint; if no folding was needed
  the fingerprint passed in is returned as-is.
  """
  while True:
    onBits = fp.GetNumOnBits()
    totalBits = fp.GetNumBits()
    if not (float(onBits) / totalBits < fpArgs['tgtDensity']):
      # dense enough already
      return fp
    if not (totalBits / 2 > fpArgs['minSize']):
      # another fold would take us to (or below) the minimum size
      return fp
    fp = DataStructs.FoldFingerprint(fp, 2)


def FingerprintMol(mol, fingerprinter, **fpArgs):
  """ generates a fingerprint for a single molecule

  Arguments:

    - mol: the molecule to be fingerprinted

    - fingerprinter: name of the fingerprint type. "MACCS" and "MORGAN"
      are recognized; any other value (including "RDKIT") produces an
      RDKit fingerprint via Chem.RDKFingerprint.

    - fpArgs: fingerprinting parameters, using the attribute names of
      FingerprinterDetails (minPath, maxPath, fpSize, bitsPerHash,
      useHs, tgtDensity, minSize, morgan_radius, morgan_nbits).
      If no keyword arguments are given at all, the defaults from
      FingerprinterDetails are used.

  Returns the fingerprint produced by the selected generator.

  """
  if not fpArgs:
    fpArgs = FingerprinterDetails().__dict__

  if fingerprinter == "MACCS":
    fp = MACCSkeys.GenMACCSKeys(mol)
  elif fingerprinter == "MORGAN":
    # AllChem is not loaded by "from rdkit import Chem", so Chem.AllChem
    # is only available if something else imported it; import it here.
    from rdkit.Chem import AllChem
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, fpArgs["morgan_radius"],
                                               fpArgs["morgan_nbits"])
  else:  # "RDKIT" and anything unrecognized
    fp = Chem.RDKFingerprint(mol, fpArgs['minPath'], fpArgs['maxPath'], fpArgs['fpSize'],
                             fpArgs['bitsPerHash'], fpArgs['useHs'], fpArgs['tgtDensity'],
                             fpArgs['minSize'])
  # only build the (potentially long) bit string when it will be emitted
  if logging.getLogger().isEnabledFor(logging.DEBUG):
    logging.debug(f"{fp.ToBitString()} ({fp.GetNumBits()} bits)")
  return fp

def FingerprintsFromSmiles(dataSource, idCol, smiCol, fingerprinter, reportFreq=10, maxMols=-1, **fpArgs):
  """ fingerprints the molecules described by SMILES in _dataSource_

  Arguments:

    - dataSource: a sized sequence of rows; each row is indexed with
      _idCol_ and _smiCol_

    - idCol, smiCol: positions of the ID and SMILES within a row

    - fingerprinter: fingerprint type name, handed to FingerprintMol

    - reportFreq: accepted for compatibility, not used here (progress
      is shown with a tqdm bar)

    - maxMols: if positive, stop after this many molecules have been
      fingerprinted successfully

    - fpArgs: passed as keyword arguments to the fingerprinter

  Rows whose SMILES cannot be parsed are logged and skipped.

  Returns a list of 2-tuples: (ID,fp)

  """
  fingerprints = []
  progressBar = None
  numDone = 0
  for row in dataSource:
    ID = str(row[idCol])
    smi = str(row[smiCol])
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
      logging.error(f"Problems parsing SMILES: {smi}")
      continue
    fingerprints.append((ID, FingerprintMol(mol, fingerprinter, **fpArgs)))
    # the bar is created lazily, on the first successful molecule
    if not progressBar:
      progressBar = tqdm.tqdm(total=len(dataSource), unit="molecules")
    progressBar.update()
    numDone += 1
    if maxMols > 0 and numDone >= maxMols:
      break
  if progressBar is not None:
    progressBar.close()
  return fingerprints


def FingerprintsFromMols(mols, fingerprinter, reportFreq=10, maxMols=-1, **fpArgs):
  """ fingerprints a collection of already-constructed molecules

  Arguments:

    - mols: a sized sequence of (ID, mol) 2-tuples

    - fingerprinter: fingerprint type name, handed to FingerprintMol

    - reportFreq: accepted for compatibility, not used here (progress
      is shown with a tqdm bar)

    - maxMols: if positive, stop after this many molecules have been
      fingerprinted successfully

    - fpArgs: passed as keyword arguments to the fingerprinter

  Entries whose molecule is missing (falsy) are logged and skipped.

  Returns a list of 2-tuples: (ID,fp)
  """
  res = []
  nDone = 0
  tq = None
  for ID, mol in mols:
    if not mol:
      # there is no SMILES in this function; report the entry by its ID
      logging.error(f"Problems with molecule for ID: {ID}")
      continue
    fp = FingerprintMol(mol, fingerprinter, **fpArgs)
    res.append((ID, fp))
    # the bar is created lazily, on the first successful molecule
    if tq is None:
      tq = tqdm.tqdm(total=len(mols), unit="molecules")
    tq.update()
    nDone += 1
    if maxMols > 0 and nDone >= maxMols:
      break
  if tq is not None:
    tq.close()
  return res


def FingerprintsFromPickles(dataSource, idCol, pklCol, fingerprinter, reportFreq=10, maxMols=-1, **fpArgs):
  """ fingerprints the molecules stored as pickles in _dataSource_

  Arguments:

    - dataSource: a sized sequence of rows; each row is indexed with
      _idCol_ and _pklCol_

    - idCol, pklCol: positions of the ID and molecule pickle within a row

    - fingerprinter: fingerprint type name, handed to FingerprintMol

    - reportFreq: accepted for compatibility, not used here (progress
      is shown with a tqdm bar)

    - maxMols: if positive, stop after this many molecules have been
      fingerprinted

    - fpArgs: passed as keyword arguments to the fingerprinter

  Returns a list of 2-tuples: (ID,fp)

  """
  results = []
  progressBar = None
  numDone = 0
  for record in dataSource:
    ID = str(record[idCol])
    # NOTE(review): the pickle is passed through str() before being handed
    # to Chem.Mol; confirm this is what is wanted when the column holds bytes
    pkl = str(record[pklCol])
    mol = Chem.Mol(pkl)
    if mol is None:
      logging.error(f"Problems parsing pickle for ID: {ID}")
      continue
    results.append((ID, FingerprintMol(mol, fingerprinter, **fpArgs)))
    # the bar is created lazily, on the first successful molecule
    if not progressBar:
      progressBar = tqdm.tqdm(total=len(dataSource), unit="molecules")
    progressBar.update()
    numDone += 1
    if maxMols > 0 and numDone >= maxMols:
      break
  if progressBar is not None:
    progressBar.close()
  return results


def FingerprintsFromDetails(details, reportFreq=10):
  """ reads molecules as described by _details_, fingerprints them and
  optionally stores the results

  The input is selected by the first of these that applies:

    - details.dbName and details.tableName are set: the ID and SMILES
      columns are pulled from that database table

    - details.inFileName and details.useSmiles are set: the ID and
      SMILES columns are read from that text file

    - details.inFileName and details.useSD are set: molecules are read
      from that SD file; records that cannot be parsed are logged and
      skipped

  Output:

    - if details.outFileName is set, each (ID, fingerprint) tuple is
      pickled, one after the other, into that file

    - if details.outTableName is set and a database name is available
      (details.outDbName, falling back to details.dbName), the
      fingerprints are inserted into that table

  Arguments:

    - details: a FingerprinterDetails instance; its _idName_ attribute
      is filled in with a default when it is empty

    - reportFreq: handed through to the fingerprinting functions

  Returns the list of (ID, fingerprint) 2-tuples, or None when no
  input data could be read.

  """
  data = None
  dataSet = None
  idCol = 0
  smiCol = 1
  if details.dbName and details.tableName:
    from rdkit.Dbase import DbInfo
    from rdkit.Dbase.DbConnection import DbConnect
    from rdkit.ML.Data import DataUtils
    try:
      # the connection itself is not used below; it is opened so that
      # connection problems get reported
      DbConnect(details.dbName, details.tableName)
    except Exception:
      logging.exception(
        f"Problems establishing connection to database: {details.dbName}|{details.tableName}")
    if not details.idName:
      details.idName = DbInfo.GetColumnNames(details.dbName, details.tableName)[0]
    dataSet = DataUtils.DBToData(details.dbName, details.tableName,
                                 what=f"{details.idName},{details.smilesName}")
  elif details.inFileName and details.useSmiles:
    from rdkit.ML.Data import DataUtils
    if not details.idName:
      details.idName = 'ID'
    try:
      dataSet = DataUtils.TextFileToData(details.inFileName,
                                         onlyCols=[details.idName, details.smilesName])
    except IOError:
      # dataSet stays None, so nothing is fingerprinted and None is returned
      logging.exception(f"Problems reading from file {details.inFileName}")
  elif details.inFileName and details.useSD:
    if not details.idName:
      details.idName = 'ID'
    dataSet = []
    for mol in Chem.SDMolSupplier(details.inFileName):
      if mol is None:
        # the supplier hands back None for records it could not parse
        logging.error(f"Problems reading a molecule from file {details.inFileName}")
        continue
      if mol.HasProp(details.idName):
        nm = mol.GetProp(details.idName)
      else:
        nm = mol.GetProp('_Name')
      dataSet.append((nm, mol))

  fps = None
  if dataSet:
    if details.useSD:
      fps = FingerprintsFromMols(dataSet, reportFreq=reportFreq, **details.__dict__)
    else:
      data = dataSet.GetNamedData()
      if not details.molPklName:
        fps = FingerprintsFromSmiles(data, idCol, smiCol, reportFreq=reportFreq,
                                     **details.__dict__)
      else:
        fps = FingerprintsFromPickles(data, idCol, smiCol, reportFreq=reportFreq,
                                      **details.__dict__)

  if fps:
    if details.outFileName:
      logging.info(f"Writing pickled FPs to {details.outFileName}")
      # one pickle per (ID, fp) tuple; the file is closed even on error
      with open(details.outFileName, 'wb+') as outF:
        for entry in fps:
          pickle.dump(entry, outF)
    dbName = details.outDbName or details.dbName
    if details.outTableName and dbName:
      from rdkit.Dbase import DbModule, DbUtils
      from rdkit.Dbase.DbConnection import DbConnect
      conn = DbConnect(dbName)
      #
      #  We don't have a db open already, so we'll need to figure out
      #    the types of our columns...
      #
      # NOTE(review): data is still None when the input was an SD file, so
      #  the next line fails in that case; confirm whether SD input
      #  together with a db output table is meant to be supported.
      colTypes = DbUtils.TypeFinder(data, len(data), len(data[0]))
      typeStrs = DbUtils.GetTypeStrings([details.idName, details.smilesName], colTypes,
                                        keyCol=details.idName)
      cols = f"{typeStrs[0]}, {details.fpColName} {DbModule.binaryTypeName}"

      # FIX: we should really check to see if the table
      #  is already there and, if so, add the appropriate
      #  column.

      #
      # create the new table
      #
      if details.replaceTable or \
         details.outTableName.upper() not in [x.upper() for x in conn.GetTableNames()]:
        conn.AddTable(details.outTableName, cols)

      #
      # And add the data
      #
      for ID, fp in fps:
        tpl = ID, DbModule.binaryHolder(fp.ToBinary())
        conn.InsertData(details.outTableName, tpl)
      conn.Commit()
  return fps
# ------------------------------------------------
#
#  Command line parsing stuff
#
# ------------------------------------------------


class FingerprinterDetails(object):
  """ holds the settings for a fingerprinting, screening or clustering
     run; every setting is given a default value on construction

  """

  def __init__(self):
    for initializer in (self._fingerprinterInit, self._screenerInit, self._clusterInit):
      initializer()

  def _fingerprinterInit(self):
    """ installs the defaults for fingerprint generation and input/output """
    defaults = (
      ('fingerprinter', "RDKIT"),
      ('fpColName', "AutoFragmentFP"),
      ('idName', ''),
      ('dbName', ''),
      ('outDbName', ''),
      ('tableName', ''),
      ('minSize', 64),
      ('fpSize', 2048),
      ('tgtDensity', 0.3),
      ('minPath', 1),
      ('maxPath', 7),
      ('discrimHash', 0),
      ('useHs', 0),
      ('useValence', 0),
      ('bitsPerHash', 2),
      ('smilesName', 'SMILES'),
      ('maxMols', -1),
      ('outFileName', ''),
      ('outTableName', ''),
      ('inFileName', ''),
      ('iheader', False),
      ('replaceTable', True),
      ('molPklName', ''),
      ('useSmiles', True),
      ('useSD', False),
      ('morgan_radius', 2),
      ('morgan_nbits', 1024),
    )
    for attrName, attrValue in defaults:
      setattr(self, attrName, attrValue)

  def _screenerInit(self):
    """ installs the defaults for similarity screening """
    defaults = (
      ('metric', DataStructs.TanimotoSimilarity),
      ('doScreen', ''),
      ('topN', 10),
      ('screenThresh', 0.75),
      ('doThreshold', 0),
      ('smilesTableName', ''),
      ('probeSmiles', ''),
      ('probeMol', None),
      ('noPickle', 0),
    )
    for attrName, attrValue in defaults:
      setattr(self, attrName, attrValue)

  def _clusterInit(self):
    """ installs the defaults for clustering """
    defaults = (
      ('clusterAlgo', Murtagh.WARDS),
      ('actTableName', ''),
      ('actName', ''),
    )
    for attrName, attrValue in defaults:
      setattr(self, attrName, attrValue)

  def GetMetricName(self):
    """ returns 'Tanimoto', 'Dice' or 'Cosine' for the corresponding
    DataStructs similarity functions; any other truthy metric is
    returned unchanged, and a falsy metric gives 'Unknown'

    """
    knownMetrics = (
      (DataStructs.TanimotoSimilarity, 'Tanimoto'),
      (DataStructs.DiceSimilarity, 'Dice'),
      (DataStructs.CosineSimilarity, 'Cosine'),
    )
    for metricFunc, label in knownMetrics:
      if self.metric == metricFunc:
        return label
    if self.metric:
      return self.metric
    return 'Unknown'

  def SetMetricFromName(self, name):
    """ sets the metric from its (case-insensitive) name: TANIMOTO, DICE
    or COSINE; any other name leaves the current metric untouched

    """
    byName = {
      "TANIMOTO": DataStructs.TanimotoSimilarity,
      "DICE": DataStructs.DiceSimilarity,
      "COSINE": DataStructs.CosineSimilarity,
    }
    key = name.upper()
    if key in byName:
      self.metric = byName[key]


def Usage():
  """  shows the command-line help text and terminates the program
  with exit status -1

  """
  print(_usageDoc)
  raise SystemExit(-1)


_usageDoc = """
Usage: FingerprintMols.py [args] <fName>

  If <fName> is provided and no tableName is specified (see below),
  data will be read from the text file <fName>.  Text files delimited
  with either commas (extension .csv) or tabs (extension .txt) are
  supported.

  Command line arguments are:
    - -d _dbName_: set the name of the database from which
      to pull input molecule information.  If output is
      going to a database, this will also be used for that
      unless the --outDbName option is used.

    - -t _tableName_: set the name of the database table
      from which to pull input molecule information

    - --smilesName=val: sets the name of the SMILES column
      in the input database.  Default is *SMILES*.

    - --useSD:  Assume that the input file is an SD file, not a SMILES
       table.

    - --idName=val: sets the name of the id column in the input
      database.  Defaults to be the name of the first db column
      (or *ID* for text files).

    - -o _outFileName_:  name of the output file (output will
      be a pickle file with one label,fingerprint entry for each
      molecule).

    - --outTable=val: name of the output db table used to store
      fingerprints.  If this table already exists, it will be
      replaced.

    - --outDbName: name of output database, if it's being used.
      Defaults to be the same as the input db.

    - --fpColName=val: name to use for the column which stores
      fingerprints (in pickled format) in the output db table.
      Default is *AutoFragmentFP*

    - --maxSize=val:  base size of the fingerprints to be generated
      Default is *2048*

    - --minSize=val: minimum size of the fingerprints to be generated
      (limits the amount of folding that happens).  Default is *64*

    - --density=val: target bit density in the fingerprint.  The
      fingerprint will be folded until this density is
      reached. Default is *0.3*

    - --minPath=val:  minimum path length to be included in
      fragment-based fingerprints. Default is *1*.

    - --maxPath=val:  maximum path length to be included in
      fragment-based fingerprints. Default is *7*.

    - --nBitsPerHash: number of bits to be set in the output
      fingerprint for each fragment. Default is *2*.

    - --discrim: use of path-based discriminators to hash bits.
      Default is *false*.

    - -V: include valence information in the fingerprints
      Default is *false*.

    - -H: include Hs in the fingerprint
      Default is *false*.

    - --maxMols=val: sets the maximum number of molecules to be
      fingerprinted.

    - --useMACCS: use the public MACCS keys to do the fingerprinting
      (instead of a daylight-type fingerprint)

"""


def ParseArgs(details=None, argv=None):
  """ parses the command line arguments and returns a
   _FingerprinterDetails_ instance with the results.

   **Arguments**:

     - details: (optional) an existing details object to be updated
       in place; a new _FingerprinterDetails_ is created if omitted.

     - argv: (optional) the list of argument strings to parse;
       defaults to sys.argv[1:].

   **Returns**: the details object (the one passed in, if any).

   **Note**:

     - If you make modifications here, please update the global
       _usageDoc string so the Usage message is up to date.

     - This routine is used by both the fingerprinter, the clusterer and the
       screener; not all arguments make sense for all applications.

     - The number of bits per hash can be given either as
       --bitsPerHash (the spelling originally registered with getopt)
       or as --nBitsPerHash (the spelling used in the usage message).

     - -s and --doScreen are accepted by the parser but have no
       handler here, so their values are ignored.

     - An unrecognized option prints the traceback and the usage
       message, which exits the program.

  """
  if argv is None:
    argv = sys.argv[1:]
  try:
    opts, extras = getopt.getopt(argv,
                                 'HVs:d:t:o:h',
                                 [
                                   'minSize=',
                                   'maxSize=',
                                   'density=',
                                   'minPath=',
                                   'maxPath=',
                                   'bitsPerHash=',
                                   'nBitsPerHash=',
                                   'smilesName=',
                                   'molPkl=',
                                   'useSD',
                                   'idName=',
                                   'discrim',
                                   'outTable=',
                                   'outDbName=',
                                   'fpColName=',
                                   'maxMols=',
                                   'useMACCS',
                                   'keepTable',
                                   # SCREENING:
                                   'smilesTable=',
                                   'doScreen=',
                                   'topN=',
                                   'thresh=',
                                   'smiles=',
                                   'dice',
                                   'cosine',
                                   # CLUSTERING:
                                   'actTable=',
                                   'actName=',
                                   'SLINK',
                                   'CLINK',
                                   'UPGMA',
                                 ])
  except getopt.GetoptError:
    import traceback
    traceback.print_exc()
    Usage()

  if details is None:
    details = FingerprinterDetails()
  if extras:
    # the first positional argument is the input file name
    details.inFileName = extras[0]

  for arg, val in opts:
    if arg == '-H':
      details.useHs = 1
    elif arg == '-V':
      details.useValence = 1
    elif arg == '-d':
      details.dbName = val
    elif arg == '-t':
      details.tableName = val
    elif arg == '-o':
      details.outFileName = val
    elif arg == '--minSize':
      details.minSize = int(val)
    elif arg == '--maxSize':
      details.fpSize = int(val)
    elif arg == '--density':
      details.tgtDensity = float(val)
    elif arg == '--outTable':
      details.outTableName = val
    elif arg == '--outDbName':
      details.outDbName = val
    elif arg == '--fpColName':
      details.fpColName = val
    elif arg == '--minPath':
      details.minPath = int(val)
    elif arg == '--maxPath':
      details.maxPath = int(val)
    elif arg in ('--bitsPerHash', '--nBitsPerHash'):
      details.bitsPerHash = int(val)
    elif arg == '--discrim':
      details.discrimHash = 1
    elif arg == '--smilesName':
      details.smilesName = val
    elif arg == '--molPkl':
      details.molPklName = val
    elif arg == '--useSD':
      details.useSmiles = False
      details.useSD = True
    elif arg == '--idName':
      details.idName = val
    elif arg == '--maxMols':
      details.maxMols = int(val)
    elif arg == '--useMACCS':
      details.fingerprinter = "MACCS"
    elif arg == '--keepTable':
      details.replaceTable = False

    # SCREENER:
    elif arg == '--smilesTable':
      details.smilesTableName = val
    elif arg == '--topN':
      details.doThreshold = 0
      details.topN = int(val)
    elif arg == '--thresh':
      details.doThreshold = 1
      details.screenThresh = float(val)
    elif arg == '--smiles':
      details.probeSmiles = val
    elif arg == '--dice':
      details.metric = DataStructs.DiceSimilarity
    elif arg == '--cosine':
      details.metric = DataStructs.CosineSimilarity

    # CLUSTERS:
    elif arg == '--SLINK':
      details.clusterAlgo = Murtagh.SLINK
    elif arg == '--CLINK':
      details.clusterAlgo = Murtagh.CLINK
    elif arg == '--UPGMA':
      details.clusterAlgo = Murtagh.UPGMA
    elif arg == '--actTable':
      details.actTableName = val
    elif arg == '--actName':
      details.actName = val
    elif arg == '-h':
      Usage()
  return details


if __name__ == '__main__':
  # Without an explicit configuration the root logger only shows
  # warnings and above, so the informational messages emitted by this
  # module would be dropped when it is run as a script.
  logging.basicConfig(level=logging.INFO)
  logging.info("This is FingerprintMols")
  details = ParseArgs()
  FingerprintsFromDetails(details)
