from typing import Optional
from classy_imaginary import config
from classy_imaginary.api import _generate_single_image
from classy_imaginary import ImaginePrompt, ImagineResult
from classy_imaginary.enhancers.face_restoration_codeformer import codeformer_model
from classy_imaginary.model_manager import get_diffusion_model
from classy_imaginary.modules.midas.utils import AddMiDaS
from classy_imaginary.safety import EnhancedStableDiffusionSafetyChecker

import logging
from transformers import logging as transformers_logging, AutoFeatureExtractor

# Silence the transformers library's own log output below ERROR
# (this is a process-wide setting applied at import time).
transformers_logging.set_verbosity_error()

# Module logger for this file, pinned to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Imagine:
    """Load a diffusion model plus its helper models once and generate images from prompts.

    All models are created in ``__init__`` and stored on the instance, so repeated
    calls to :meth:`imagine` reuse them instead of reloading.
    """

    # Hugging Face model id used for both the safety feature extractor and the safety checker.
    SAFETY_MODEL_ID = "CompVis/stable-diffusion-safety-checker"

    def __init__(self,
                 model_name: str = config.DEFAULT_MODEL,
                 half_mode: Optional[bool] = None,
                 precision: str = "autocast",
                 for_inpainting: bool = False):
        """
        Initialize the Imagine class and load every model it needs.

        :param model_name: the name of the SD model to use; it must be known to
            ``config.get_model_config``.
        :param half_mode: whether to use half-precision. If None, will use half-precision if available.
            The value is forwarded both to ``get_diffusion_model`` and, on each
            :meth:`imagine` call, to ``_generate_single_image``.
        :param precision: whether to use autocast or not.
            NOTE(review): this is stored on the instance but is not currently passed on
            by :meth:`imagine`, so it has no visible effect here — confirm intent.
        :param for_inpainting: whether to use the model for inpainting. When True, an
            ``AddMiDaS`` helper is also created and passed to every generation call.
        :raises ValueError: if ``config.get_model_config`` returns None for ``model_name``.
        """
        self.model_name = model_name
        self.half_mode = half_mode
        self.precision = precision
        self.for_inpainting = for_inpainting

        model_config = config.get_model_config(model_name)
        if model_config is None:
            raise ValueError(f"Unknown model name: {model_name}")

        self.sd_model = get_diffusion_model(
            weights_location=model_config.weights_url,
            config_path=model_config.config_path,
            half_mode=half_mode,
            for_inpainting=for_inpainting,
        )

        # The MiDaS helper is only needed (and only built) for inpainting models.
        self.midas_model = AddMiDaS() if for_inpainting else None

        self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(self.SAFETY_MODEL_ID)
        self.safety_checker = EnhancedStableDiffusionSafetyChecker.from_pretrained(
            self.SAFETY_MODEL_ID
        )
        self.code_former_model = codeformer_model()

    def imagine(self,
                prompt: ImaginePrompt,
                nsfw_filter: bool = False,
                debug_img_callback=None,
                progress_img_callback=None,
                progress_img_interval_steps: int = 3,
                progress_img_interval_min_s: float = 0.1) -> Optional[ImagineResult]:
        """
        Run inference on the model for a single prompt.

        :param prompt: ImaginePrompt.
        :param nsfw_filter: whether to filter out NSFW images.
        :param debug_img_callback: a callback that will be called with the debug image.
        :param progress_img_callback: a callback that will be called with the progress image.
        :param progress_img_interval_steps: the number of steps between progress images.
        :param progress_img_interval_min_s: the minimum time between progress images.
        :return: the single result produced by ``_generate_single_image`` for ``prompt``
            (annotated Optional, so it may be None) — not a list.
        """
        return _generate_single_image(
            prompt,
            self.sd_model,
            midas_model=self.midas_model,
            nsfw_filter=nsfw_filter,
            safety_feature_extractor=self.safety_feature_extractor,
            safety_checker=self.safety_checker,
            # NOTE(review): keyword is spelled ``codformer_model`` (sic); presumably it matches
            # ``_generate_single_image``'s parameter name — confirm before renaming.
            codformer_model=self.code_former_model,
            debug_img_callback=debug_img_callback,
            progress_img_callback=progress_img_callback,
            progress_img_interval_steps=progress_img_interval_steps,
            progress_img_interval_min_s=progress_img_interval_min_s,
            half_mode=self.half_mode,
            add_caption=False,
        )
