from .tasks import BaseTaskBackend, BaseStepTaskManager
from .pipelines import Pipeline
from .loggs import FileFormatter
from pathlib import Path
from traceback import format_exc as format_traceback_exc
import logging
import coloredlogs
from logging import getLogger
from functools import wraps
from platform import node
from pandas import Series

from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    from celery import Celery
    from .steps import BaseStep


# Registry of the Celery apps built by ``create_celery_app``, keyed by the ``app_name`` they were created with.
# ``CeleryTaskRecord.get_application`` reads from it to find the app (and its pipelines) for a task.
APPLICATIONS_STORE = dict()


def get_runner(task_name: str):
    """Build a Celery ``Task`` subclass that executes a single pipeline step.

    The returned class's ``run(task_id, extra=None)`` method does the following:

    - reads the task record from Alyx;
    - resolves the step from the application registered in ``APPLICATIONS_STORE``;
    - calls ``step.generate`` while logging to a per-task file;
    - writes the resulting status back to Alyx.

    Args:
        task_name: Name under which Celery registers the task (the step's ``complete_name``).

    Returns:
        A ``celery.Task`` subclass (the class itself, not an instance).
    """
    # Imported lazily, so this module can be imported without celery installed.
    from celery import Task

    class CeleryRunner(Task):
        name = task_name

        def run(self, task_id, extra=None):

            task = CeleryTaskRecord(task_id)

            try:
                session = task.get_session()
                application = task.get_application()
                # Work on a copy: keys popped below must not be removed from the record sent back to Alyx.
                arguments = dict(task.arguments)

                with LogTask(task) as log_object:
                    logger = log_object.logger
                    task["log"] = log_object.filename
                    task["status"] = "Started"
                    task.partial_update()

                    try:
                        step: "BaseStep" = (
                            application.pipelines[task.pipeline_name].pipes[task.pipe_name].steps[task.step_name]
                        )
                        # skip is True unless a "refresh" or "refresh_requirements" argument was supplied.
                        skip = not (arguments.get("refresh", False) or arguments.get("refresh_requirements", []))
                        # ``skip`` is passed explicitly below; drop any user-supplied value to avoid a
                        # "got multiple values for keyword argument 'skip'" TypeError.
                        arguments.pop("skip", None)

                        step.generate(session, extra=extra, skip=skip, check_requirements=True, **arguments)
                        task.status_from_logs(log_object)
                    except Exception as e:
                        traceback_msg = format_traceback_exc()
                        logger.critical(f"Fatal Error : {e}")
                        logger.critical("Traceback :\n" + traceback_msg)
                        task["status"] = "Failed"

            except Exception as e:
                # A failure outside the nested try happens before (or while) the log file is set up,
                # so the error message is stored directly in the Alyx record instead.
                task["status"] = "Uncatched_Fail"
                task["log"] = str(e)

            task.partial_update()

    return CeleryRunner


class CeleryAlyxTaskManager(BaseStepTaskManager):
    """Step task manager that dispatches a step to a Celery worker and tracks it as an Alyx task."""

    backend: "CeleryTaskBackend"
    step: "BaseStep"

    def register_step(self):
        """Register a Celery runner named after the step's ``complete_name``, if a backend is attached."""
        if not self.backend:
            return
        runner_class = get_runner(self.step.complete_name)
        self.backend.app.register_task(runner_class)

    def start(self, session, extra=None, **kwargs):
        """Create the Alyx task record for ``session`` and send it to the Celery worker.

        Returns:
            The ``CeleryTaskRecord`` built by ``CeleryTaskRecord.create``.

        Raises:
            NotImplementedError: if the manager has no (truthy) backend.
        """
        if self.backend:
            return CeleryTaskRecord.create(self, session, extra, **kwargs)

        raise NotImplementedError(
            "Cannot start a task on a celery cluster as this pipeline doesn't have a working celery backend"
        )


class CeleryTaskRecord(dict):
    """Dictionary holding an Alyx ``tasks`` record, with helpers to read and update it on Alyx.

    The dict content is the JSON returned by the Alyx ``tasks`` REST endpoint (``read`` or ``create``).
    Two pieces of state live on the instance rather than in the dict:

    - ``session``: the session object, fetched lazily by :meth:`get_session`;
    - ``response``: the handle returned by the Celery ``delay`` call in :meth:`create`.
    """

    session: Series

    def __init__(self, task_id, task_infos_dict=None, response_handle=None, session=None):
        """Wrap an existing task record, fetching it from Alyx when it is not provided.

        Args:
            task_id: Alyx id of the task; only used to read the record when ``task_infos_dict`` is empty.
            task_infos_dict: Already-fetched task record. When None or empty, it is read from Alyx.
            response_handle: Value stored as ``self.response`` (the result of ``worker.delay`` in :meth:`create`).
            session: Optional session object; when None, :meth:`get_session` fetches it on first use.
        """
        if not task_infos_dict:
            task_infos_dict = self._get_connector().alyx.rest("tasks", "read", id=task_id)

        super().__init__(task_infos_dict)
        self.session = session  # type: ignore
        self.response = response_handle

    @staticmethod
    def _get_connector():
        """Return a remote-only ``ONE`` connector (imported lazily, like elsewhere in this module)."""
        from one import ONE

        return ONE(mode="remote", data_access_mode="remote")

    def status_from_logs(self, log_object):
        """Set ``self["status"]`` from the most severe level name found in the task's log file.

        The checks are, in order:

        - empty file: "No_Info";
        - "CRITICAL": "Failed";
        - "ERROR": "Errors";
        - "WARNING": "Warnings";
        - otherwise: "Complete".

        Matching is a plain substring search over the whole file content.

        Args:
            log_object: Object with a ``fullpath`` attribute pointing to the log file (e.g. a ``LogTask``).
        """
        with open(log_object.fullpath, "r") as f:
            content = f.read()

        if len(content) == 0:
            status = "No_Info"
        elif "CRITICAL" in content:
            status = "Failed"
        elif "ERROR" in content:
            status = "Errors"
        elif "WARNING" in content:
            status = "Warnings"
        else:
            status = "Complete"

        self["status"] = status

    def partial_update(self):
        """Push the current record content to Alyx (``tasks`` ``partial_update``)."""
        self._get_connector().alyx.rest("tasks", "partial_update", **self.export())

    def get_session(self):
        """Return the session object, querying Alyx for ``self["session"]`` on first access."""
        if self.session is None:
            # NOTE(review): upstream ONE's ``search(details=True)`` returns an (eids, details) tuple;
            # confirm the ``one`` package used here returns a Series-like object with a "path" entry.
            session = self._get_connector().search(id=self["session"], no_cache=True, details=True)
            self.session = session  # type: ignore

        return self.session

    def get_application(self):
        """Return the Celery app registered in ``APPLICATIONS_STORE`` for this task's ``executable``.

        ``executable`` is stored as ``str(app.main)`` (the app's display name, see :meth:`create`).
        ``APPLICATIONS_STORE`` is keyed by the ``app_name`` given to ``create_celery_app``. The two
        differ when ``app_display_name`` is configured, so apps are also matched on their ``main``.

        Raises:
            KeyError: if no registered application matches.
        """
        executable = self["executable"]
        try:
            return APPLICATIONS_STORE[executable]
        except KeyError:
            pass

        for app in APPLICATIONS_STORE.values():
            if str(getattr(app, "main", None)) == executable:
                return app

        raise KeyError(f"Unable to retrieve the application {executable}")

    @property
    def pipeline_name(self):
        """First dot-separated component of ``self["name"]`` (``<pipeline>.<pipe>.<step>``)."""
        return self["name"].split(".")[0]

    @property
    def pipe_name(self):
        """Second dot-separated component of ``self["name"]``."""
        return self["name"].split(".")[1]

    @property
    def step_name(self):
        """Third dot-separated component of ``self["name"]``."""
        return self["name"].split(".")[2]

    @property
    def arguments(self):
        """Return the stored ``arguments`` object, or a new empty dict when it is missing, None or empty.

        Note that a non-empty value is the stored object itself (not a copy).
        """
        args = self.get("arguments", {})
        return args if args else {}

    @property
    def session_path(self) -> str:
        """The ``"path"`` entry of the attached session object."""
        return self.session["path"]

    @property
    def task_id(self):
        """Alyx id of the task."""
        return self["id"]

    def export(self):
        """Build the keyword arguments for an Alyx ``partial_update``.

        Returns:
            ``{"id": <task id>, "data": <every other key except "session_path">}``.
        """
        return {"id": self["id"], "data": {k: v for k, v in self.items() if k not in ["id", "session_path"]}}

    @staticmethod
    def create(task_manager: "CeleryAlyxTaskManager", session, extra=None, **kwargs):
        """Create a "Waiting" task on Alyx, then send it to the step's registered Celery task.

        Args:
            task_manager: Manager whose ``step`` and ``backend.app`` identify the task to run.
            session: Session object; its ``name`` is used as the Alyx session id.
            extra: Forwarded to the worker's ``run``.
            **kwargs: Stored as the task's ``arguments`` on Alyx.

        Returns:
            A ``CeleryTaskRecord`` holding the created record and the ``delay`` response handle.
        """
        connector = CeleryTaskRecord._get_connector()

        data = {
            "session": session.name,
            "name": task_manager.step.complete_name,
            "arguments": kwargs,
            "status": "Waiting",
            "executable": str(task_manager.backend.app.main),
        }

        task_dict = connector.alyx.rest("tasks", "create", data=data)

        worker = task_manager.backend.app.tasks[task_manager.step.complete_name]
        response_handle = worker.delay(task_dict["id"], extra=extra)

        return CeleryTaskRecord(
            task_dict["id"], task_infos_dict=task_dict, response_handle=response_handle, session=session
        )


class CeleryTaskBackend(BaseTaskBackend):
    """Task backend that binds a pipeline to a Celery application.

    When an app is supplied, ``success`` is set to True. The pipeline is also recorded in the app's
    ``pipelines`` dict under its ``pipeline_name``; the dict is created on the app if it is missing.
    """

    app: "Celery"
    task_manager_class = CeleryAlyxTaskManager

    def __init__(self, parent: Pipeline, app: "Celery | None" = None):
        super().__init__(parent)
        self.parent = parent

        if app is None:
            return

        self.success = True
        self.app = app

        registered = getattr(self.app, "pipelines", {})
        registered[parent.pipeline_name] = parent
        self.app.pipelines = registered

    def start(self):
        """Call ``start()`` on the underlying Celery application."""
        self.app.start()

    def create_task_manager(self, step):
        """Instantiate ``task_manager_class`` for ``step`` and register its Celery runner task."""
        manager = self.task_manager_class(step, self)
        manager.register_step()
        return manager


class CeleryPipeline(Pipeline):
    """``Pipeline`` variant whose ``runner_backend_class`` is :class:`CeleryTaskBackend`."""

    runner_backend_class = CeleryTaskBackend


def get_setting_files_path(conf_path, app_name) -> List[Path]:
    """Return the existing Celery TOML config files for ``app_name``.

    The folder searched is ``conf_path`` itself, or its parent when ``conf_path`` is a file. The
    candidates, in this order, are ``celery_<app_name>.toml`` and ``.celery_<app_name>_secrets.toml``;
    only those that exist as files are returned.
    """
    folder = Path(conf_path)
    if folder.is_file():
        folder = folder.parent

    candidates = (
        folder / f"celery_{app_name}.toml",
        folder / f".celery_{app_name}_secrets.toml",
    )
    return [candidate for candidate in candidates if candidate.is_file()]


class LogTask:
    """Context manager that copies root-logger output into a per-task log file.

    On enter, a ``logging.FileHandler`` is attached to the root logger. It writes to
    ``<session_path>/logs/task_log.<task name>.<task id>.log``. On exit, that same handler is
    detached and closed, leaving any other handler untouched.
    """

    def __init__(self, task_record: "CeleryTaskRecord", username=None, level="LOAD"):
        """
        Args:
            task_record: Task whose ``session_path``, ``task_id`` and ``"name"`` define the log location.
            username: Name injected in log records; defaults to the host name, or "unknown" if empty.
            level: Name of a ``logging`` module level attribute, resolved with ``getattr``. "LOAD" is not
                a standard level; presumably the project registers it elsewhere — confirm.
        """
        self.path = Path(task_record.session_path) / "logs"
        self.username = username if username is not None else (node() if node() else "unknown")
        self.worker_pk = task_record.task_id
        self.task_name = task_record["name"]
        self.level = getattr(logging, level.upper())
        # The handler added by set_handler, kept so exactly that one is removed and closed on exit.
        self.handler = None

    def __enter__(self):
        self.path.mkdir(exist_ok=True)
        self.logger = getLogger()
        self.set_handler()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.remove_handler()

    def set_handler(self):
        """Create the file handler (with host/program/user name filters) and attach it to the logger."""
        self.filename = f"task_log.{self.task_name}.{self.worker_pk}.log"
        self.fullpath = self.path / self.filename
        fh = logging.FileHandler(self.fullpath)
        f_formater = FileFormatter()
        coloredlogs.HostNameFilter.install(
            fmt=f_formater.FORMAT,
            handler=fh,
            style=f_formater.STYLE,
            use_chroot=True,
        )
        coloredlogs.ProgramNameFilter.install(
            fmt=f_formater.FORMAT,
            handler=fh,
            programname=self.task_name,
            style=f_formater.STYLE,
        )
        coloredlogs.UserNameFilter.install(
            fmt=f_formater.FORMAT,
            handler=fh,
            username=self.username,
            style=f_formater.STYLE,
        )

        fh.setLevel(self.level)
        fh.setFormatter(f_formater)
        self.logger.addHandler(fh)
        self.handler = fh

    def remove_handler(self):
        """Detach and close the handler added by :meth:`set_handler`; no-op if there is none.

        Removing ``handlers[-1]`` instead would detach the wrong handler if another one was added
        meanwhile. Closing releases the log file's descriptor, which would otherwise leak per task.
        """
        handler = getattr(self, "handler", None)
        if handler is None:
            return
        self.logger.removeHandler(handler)
        handler.close()
        self.handler = None


def create_celery_app(conf_path, app_name="pypelines", v_host=None) -> "Celery | None":
    """Create (or reuse) a Celery application configured from TOML files.

    Settings are loaded with Dynaconf from the files found by ``get_setting_files_path(conf_path, app_name)``.
    Expected failures return None after logging a warning instead of raising. These are: missing config
    files, dynaconf or celery not importable, missing settings, or app creation/configuration errors.

    Args:
        conf_path: Folder containing the TOML files, or a file inside that folder.
        app_name: Key used in ``APPLICATIONS_STORE`` and in the config file names.
        v_host: Broker virtual host; when None, ``broker_conf.virtual_host`` from the settings is used.

    Returns:
        The new Celery app (also stored in ``APPLICATIONS_STORE[app_name]`` and given a "handshake"
        task), the already-registered app with the same ``app_name``, or None on failure.
    """

    failure_message = (
        f"Celery app : {app_name} failed to be created. "
        "Don't worry about this alert, "
        "this is not an issue if you didn't explicitly plan on using celery. Issue was :"
    )

    logger = getLogger("pypelines.create_celery_app")

    if app_name in APPLICATIONS_STORE:
        logger.warning(f"Tried to create a celery app named {app_name}, but it already exists. Returning it instead.")
        return APPLICATIONS_STORE[app_name]

    settings_files = get_setting_files_path(conf_path, app_name)

    if not settings_files:
        logger.warning(f"{failure_message} Could not find celery toml config files.")
        return None

    try:
        from dynaconf import Dynaconf
    except ImportError:
        logger.warning(f"{failure_message} Could not import dynaconf. Maybe it is not installed ?")
        return None

    try:
        settings = Dynaconf(settings_files=settings_files)
    except Exception as e:
        logger.warning(f"{failure_message} Could not create dynaconf object. {e}")
        return None

    try:
        app_display_name = settings.get("app_display_name", app_name)
        broker_type = settings.connexion.broker_type
        account = settings.account
        password = settings.password
        address = settings.address
        backend = settings.connexion.backend
        conf_data = settings.conf
        # Only read from the settings when the caller did not provide a virtual host.
        v_host = settings.broker_conf.virtual_host if v_host is None else v_host
    except (AttributeError, KeyError) as e:
        logger.warning(f"{failure_message} {e}")
        return None

    try:
        from celery import Celery
    except ImportError:
        logger.warning(f"{failure_message} Could not import celery. Maybe it is not installed ?")
        return None

    try:
        app = Celery(
            app_display_name,
            broker=f"{broker_type}://{account}:{password}@{address}/{v_host}",
            backend=f"{backend}://",
        )
    except Exception as e:
        logger.warning(f"{failure_message} Could not create app. Maybe rabbitmq server @{address} is not running ? {e}")
        return None

    for key, value in conf_data.items():
        try:
            setattr(app.conf, key, value)
        except Exception as e:
            logger.warning(f"{failure_message} Could not assign extra attribute {key} to celery app. {e}")
            return None

    APPLICATIONS_STORE[app_name] = app

    from celery import Task

    class handshake(Task):
        name = "handshake"

        def run(self):
            return f"{node()} is happy to shake your hand and says hello !"

    app.register_task(handshake)

    return app
