from abc import ABC, abstractmethod
from typing import Tuple

import numpy as np
from rdkit.Chem import Mol

from deepmol.datasets import Dataset
from deepmol.loggers.logger import Logger
from deepmol.parallelism.multiprocessing import JoblibMultiprocessing
from deepmol.utils.utils import canonicalize_mol_object, mol_to_smiles


class MolecularStandardizer(ABC):
    """
    Abstract base class for molecule standardizers.

    Subclasses implement `_standardize` for a single molecule; this class
    provides the per-molecule wrapper with a best-effort fallback
    (`_standardize_mol`) and the dataset-level entry point (`standardize`).
    """

    def __init__(self, n_jobs: int = -1) -> None:
        """
        Standardizer for molecules.

        Parameters
        ----------
        n_jobs: int
            Number of jobs to run in parallel. The value is forwarded unchanged
            to JoblibMultiprocessing when `standardize` is called.
        """
        self.n_jobs = n_jobs
        self.logger = Logger()
        self.logger.info(f"Standardizer {self.__class__.__name__} initialized with {n_jobs} jobs.")

    def _standardize_mol(self, mol: Mol) -> Tuple[Mol, str]:
        """
        Standardizes a single molecule.

        The molecule is first passed through `canonicalize_mol_object` and then
        through the subclass-specific `_standardize`. This is a best-effort
        operation: if the molecule is None or any step raises, the input
        molecule is returned as is, together with the SMILES produced from it.

        Parameters
        ----------
        mol: Mol
            Molecule to standardize.

        Returns
        -------
        mol: Mol
            Standardized Mol object (or the input molecule if standardization failed).
        smiles: str
            SMILES string generated by `mol_to_smiles(..., canonical=True)` for the
            returned molecule.
        """
        try:
            # Explicit check instead of `assert`: assertions are stripped when Python
            # runs with -O, which would let None reach the standardization steps.
            if mol is None:
                raise ValueError("Cannot standardize a None molecule.")
            canonical_mol = canonicalize_mol_object(mol)
            standardized_mol = self._standardize(canonical_mol)
            return standardized_mol, mol_to_smiles(standardized_mol, canonical=True)
        except Exception:
            # Deliberate best-effort fallback: a molecule that cannot be standardized
            # must not abort the whole dataset, so the original molecule is kept.
            return mol, mol_to_smiles(mol, canonical=True)

    def standardize(self, dataset: Dataset) -> Dataset:
        """
        Standardizes a dataset of molecules.

        Every molecule in `dataset.mols` is processed with `_standardize_mol`
        through JoblibMultiprocessing using `self.n_jobs`. The dataset is
        modified in place: its `_mols` and `_smiles` attributes are overwritten
        with the results.

        Parameters
        ----------
        dataset: Dataset
            Dataset to standardize.

        Returns
        -------
        dataset: Dataset
            The same dataset object that was passed in, now holding the
            standardized molecules and SMILES.
        """
        molecules = dataset.mols
        multiprocessing_cls = JoblibMultiprocessing(n_jobs=self.n_jobs, process=self._standardize_mol)
        # Each result item is the (mol, smiles) pair returned by _standardize_mol.
        result = list(multiprocessing_cls.run(molecules))
        dataset._smiles = np.asarray([smiles for _, smiles in result])
        dataset._mols = np.asarray([standardized for standardized, _ in result])
        return dataset

    @abstractmethod
    def _standardize(self, mol: Mol) -> Mol:
        """
        Standardizes a molecule.

        Must be implemented by subclasses. Exceptions raised here are caught by
        `_standardize_mol`, which then falls back to the unmodified input molecule.

        Parameters
        ----------
        mol: Mol
            RDKit Mol object

        Returns
        -------
        mol: Mol
            Standardized mol.
        """
