# -*- coding: utf-8 -*-
import inspect
import sys
from typing import Any, Callable, Dict, overload, Type, Optional, Union

import typeguard

from outflow.core.target import Target, TargetException, NoDefault
from outflow.core.types import Skipped, Parameter
from typing_extensions import TypedDict

from outflow.core.exceptions import (
    TaskException,
    IOCheckerError,
    TaskWithKwargsException,
    ContextArgumentException,
)
from outflow.core.generic.string import to_camel_case
from outflow.core.pipeline import config
from outflow.core.logging import logger

from outflow.core.block import Block


class Task(Block):
    """Base class for outflow tasks.

    A task wraps a ``run`` callable. When an instance is created, its targets
    are derived from ``run``:

    * input targets from the arguments of ``run`` (annotation and default value);
    * parameter targets from arguments annotated with a ``Parameter`` type, in
      addition to those already present in the class-level ``_parameters``;
    * output targets from the return annotation of ``run`` or, when there is
      none and no output was declared, from the keys of the dictionary
      literals returned in the source code of ``run``.

    Calling the instance type-checks its inputs, runs ``run`` and returns a
    dict mapping output names to values.
    """

    # Parameter targets declared on the class. __init_subclass__ gives every
    # subclass its own empty dict so subclasses do not share it.
    _parameters: Dict[str, Target] = None
    # True when ``run`` is a regular method taking ``self``; False when it is a
    # plain function (stored as a staticmethod). Only annotated here:
    # __init_subclass__ defaults it to True when a subclass does not set it.
    with_self: bool
    block_type: str = "task"

    @classmethod
    def __init_subclass__(cls, **kwargs):
        """Give each subclass its own parameters dict and normalize ``run``."""
        super().__init_subclass__(**kwargs)
        # reset the targets definition attribute to avoid sharing targets definition with subclasses
        cls._parameters = {}

        # NOTE(review): a plain function's __annotations__ is always a dict (it
        # cannot be None), so this branch looks unreachable; kept as-is.
        if cls.run.__annotations__ is None:
            cls.run.__annotations__ = dict()

        # defaults with_self to True for inherited classes (as_task decorator sets to false by default)
        if not hasattr(cls, "with_self"):
            cls.with_self = True

        # without "self", run is a plain function: store it as a staticmethod so
        # that calling it through an instance does not bind the instance
        if not cls.with_self:
            cls.run = staticmethod(cls.run)

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # cache used by the parameterized_kwargs property
        self._parametrized_kwargs = None
        self.skip_if_upstream_skip = True

        # NOTE(review): this aliases the class-level dict, so parameters added
        # later (setup_auto_inputs, add_parameter) are shared by every instance
        # of the class — confirm this is intended.
        self.parameters = self._parameters

        self.setup_targets()

    def __call__(self, *args, **kwargs):
        """Run the task and return its outputs as a dict.

        Keyword arguments are the task inputs; configured parameters (see
        ``parameterized_kwargs``) are merged on top of them.

        Raises:
            Exception: if positional arguments are given.
            TargetException: if ``run`` returns a non-dict value while the task
                declares more than one output.
            TypeError: if an input or output does not match its declared type.
        """
        # Blocks are called with block_db in kwargs, so that workflows can pass it to the BlockRunner and the
        # BlockRunner sets the right parent block in database
        # this is useless for tasks, delete it so that it is not passed to the run method
        kwargs.pop("block_db", None)

        logger.debug(f"Running task {self.name}")
        if args:
            raise Exception(
                "Task use keyword-only arguments but positional arguments were passed to the run function"
            )
        task_inputs = {**kwargs, **self.parameterized_kwargs}

        memo = self.generate_memo(task_inputs)
        self.check_argument_types(memo)

        return_value = self.run(**task_inputs)

        if not isinstance(return_value, dict):
            if isinstance(return_value, Skipped):
                # if the returned object is Skipped(), put skipped in all outputs.
                # The type check is deliberately skipped: Skipped is never an
                # instance of the declared output types, so checking the
                # expanded dict would always raise a TypeError.
                return {output: return_value for output in self.outputs}
            elif len(self.outputs) > 1:
                raise TargetException(
                    f"Task {self.name} must return a dictionary since it has defined more than one output target"
                )
            elif len(self.outputs) == 1:
                # a single returned object goes into the only output
                return_value = {next(iter(self.outputs)): return_value}

        self.check_return_type(return_value)

        return return_value

    def generate_memo(self, task_kwargs: dict):
        """Build the typeguard call memo used to check the ``run`` arguments.

        Raises:
            TypeError: re-raised with the task name when typeguard cannot bind
                ``task_kwargs`` to the signature of ``run``.
        """
        # Locals of the calling frame (Task.__call__ when invoked from there),
        # presumably used by typeguard to resolve forward references — verify
        # against the typeguard version in use.
        _localns = sys._getframe(1).f_locals

        # stored needed info to check the task input/output types
        try:
            return typeguard._CallMemo(self.run, _localns, args=[], kwargs=task_kwargs)
        except TypeError as e:
            raise TypeError(f"In task '{self.name}' - {e}") from e

    def check_argument_types(self, memo):
        """Check the ``run`` arguments in ``memo``, prefixing errors with the task name."""
        try:
            typeguard.check_argument_types(memo)
        except TypeError as e:
            raise TypeError(f"In task '{self.name}' - {e}") from e

    def check_return_type(self, return_value):
        """A custom version of typeguard.check_return_type(return_value, memo)

        A dict return value is checked against a TypedDict built from the
        output targets. Any other value is checked against the type of the first
        output target. ``None`` and ``Skipped`` values are not checked.

        Raises:
            TypeError: if the value does not match the declared output types, or
                if a non-dict value is returned by a task that has no output.
        """
        if return_value is None or isinstance(return_value, Skipped):
            return

        if isinstance(return_value, dict):
            typed_dict = {
                output_target.name: output_target.type
                for output_target in self.outputs.values()
            }
            return_type = TypedDict("return_type", typed_dict)
        elif not self.outputs:
            # previously this fell through to next(iter(...)) and leaked a bare StopIteration
            raise TypeError(
                f"In task '{self.name}' - the run function returned a value of type "
                f"{type(return_value).__name__} but the task has no output"
            )
        else:
            return_type = next(iter(self.outputs.values())).type

        try:
            typeguard.check_type("return_value", return_value, return_type)
        except TypeError as e:
            raise TypeError(f"In task '{self.name}' - {e}") from e

    def setup_targets(self):
        """Create the input, output and parameter targets from the run function."""
        self.setup_auto_inputs()
        self.setup_auto_outputs()
        self.setup_auto_parameters()

    def setup_auto_parameters(self):
        """
        Exists only to support legacy @target.Parameter decorator.

        Removes from the inputs every name that is also a parameter.
        """
        for parameter in self.parameters:
            if parameter in self.inputs:
                del self.inputs[parameter]

    def setup_auto_inputs(self):
        """Create input and parameter targets from the arguments of ``run``.

        An argument whose annotation's ``__name__`` starts with
        ``Parameter.prefix`` becomes a parameter target typed with the
        annotation's ``__supertype__``. Every other argument becomes an input
        target, unless an input of that name already exists.

        Raises:
            TaskWithKwargsException: if ``run`` has keyword-only arguments.
            ContextArgumentException: if ``with_self`` is set but the first
                argument of ``run`` is not named ``self``.
        """
        # get the names and default values of the run function parameters and return annotation
        full_args_spec = inspect.getfullargspec(self.run)

        if full_args_spec.kwonlyargs:
            raise TaskWithKwargsException(
                f"Tasks cannot have keyword only args (in task '{self.name}')"
            )

        arg_names = full_args_spec.args
        annotations = full_args_spec.annotations

        # defaults apply to the trailing arguments; the leading ones get NoDefault
        explicit_defaults = list(full_args_spec.defaults or ())
        defaults = [NoDefault] * (len(arg_names) - len(explicit_defaults)) + explicit_defaults

        for index, input_arg_name in enumerate(arg_names):
            if self.with_self and index == 0:
                if input_arg_name != "self":
                    raise ContextArgumentException(
                        f"The task '{self}' was declared with context but the first argument of the run method is not self"
                    )
                # skip the "self" argument for run class method
                continue

            if input_arg_name in self.inputs:
                continue

            input_type = annotations.get(input_arg_name, Any)

            try:
                is_parameter = input_type.__name__.startswith(Parameter.prefix)
            except AttributeError:
                # annotations without a __name__ (e.g. some typing constructs)
                is_parameter = False

            if is_parameter:
                self.parameters.update(
                    {
                        input_arg_name: Target(
                            input_arg_name, type=input_type.__supertype__
                        )
                    }
                )
            else:
                self.inputs.update(
                    {
                        input_arg_name: Target(
                            input_arg_name, type=input_type, default=defaults[index]
                        )
                    }
                )

    def setup_auto_outputs(self):
        """Create output targets from the return annotation of ``run``.

        The annotation may be a dict (``{"name": type}``) or a TypedDict class.
        Without a return annotation, and only if no output exists yet, the
        source of ``run`` is parsed and the keys of every returned dictionary
        literal become outputs. A bare ``return`` is ignored.

        Raises:
            TargetException: if the return annotation is neither a dict nor a
                TypedDict, or if the outputs cannot be inferred from the source.
        """
        # get the names and default values of the run function parameters and return annotation
        full_args_spec = inspect.getfullargspec(self.run)

        # create target outputs from return annotations
        return_annotation = full_args_spec.annotations.get("return", None)

        if return_annotation is not None:
            if isinstance(return_annotation, dict):
                # Check if return annotation is a simple dict
                type_dict_items = list(return_annotation.items())
            elif return_annotation.__class__.__name__ == "_TypedDictMeta":
                # Check if return annotation is a manually defined TypedDict
                type_dict_items = list(return_annotation.__annotations__.items())
            else:
                raise TargetException(
                    "The return annotation of a task must be a dict or a TypedDict"
                )

            for output_name, output_type in type_dict_items:
                self.outputs.update(
                    {output_name: Target(output_name, type=output_type)}
                )

        elif not self.outputs:
            # import these modules only if using auto output targets
            import ast
            import textwrap

            task = self

            class Visitor(ast.NodeVisitor):
                def visit_Return(self, node: ast.Return):
                    if node.value is None:
                        # a bare "return" (early exit) produces no output
                        return
                    try:
                        keys = node.value.keys
                    except AttributeError:
                        raise Exception(
                            f"Outflow could not automatically determine the outputs of task {task.name}"
                        )
                    for key in keys:
                        # literal_eval reads literal keys on every Python version;
                        # the ``.s`` node attribute is deprecated since 3.8 and
                        # removed in Python 3.14. Non-literal keys raise ValueError.
                        task.add_output(ast.literal_eval(key))

            try:
                run_source = textwrap.dedent(inspect.getsource(self.run))
                Visitor().visit(ast.parse(run_source))
            except Exception as err:
                raise TargetException(
                    f"Could not automatically determine outputs of task {self}. Check that this tasks "
                    f"returns a dictionary. To disable automatic task outputs, call Task constructor with "
                    f"auto_outputs=False "
                ) from err

    def add_parameter(self, target):
        """Register ``target`` as a parameter, removing any input of the same name."""
        if target.name in self.inputs:
            del self.inputs[target.name]
        self.parameters.update({target.name: target})

    @classmethod
    def as_task(
        cls,
        run_func: Callable = None,
        name: Optional[str] = None,
        with_self: bool = False,
        plugin_name: Optional[str] = None,
    ) -> Union[Type["Task"], Callable[[Callable], Type["Task"]]]:
        """
        Transform the decorated function into a outflow task (class).

        The definition of the 'as_task' decorator is done outside of the class to ensure proper function typing with overloads
        """
        return as_task(run_func, cls, name, with_self, plugin_name)

    @property
    def parameterized_kwargs(self):
        """Generate the parameters kwargs dict from the config file content

        'parameterized_kwargs' refers to parameters used as task arguments and declared in the configuration file.
        The result is computed once and cached on the instance.

        Raises:
            TaskException: Raise an exception if task parameters are set but no parameters configuration is found

        Returns:
            dict: kwargs dict generated from the config file content
        """
        if self._parametrized_kwargs is not None:
            return self._parametrized_kwargs

        kwargs = {}
        if self.parameters:
            for parameter in self.parameters:
                try:
                    kwargs.update(
                        {parameter: config["parameters"][self.name][parameter]}
                    )
                except KeyError as err:
                    raise TaskException(
                        f"Could not find parameter {parameter} for task {self.name} in configuration file"
                    ) from err
            # warn about configured values that the task does not use
            for config_param in config["parameters"][self.name]:
                if config_param not in kwargs:
                    logger.warning(
                        f"Task parameter {config_param} defined in configuration file but not retrieved by task {self.name}."
                    )
        self._parametrized_kwargs = kwargs
        return kwargs

    def check_inputs(self, task_inputs: dict, values_from_upstream: dict):
        """Ensure every input without a default value is provided.

        An input counts as provided when it is in ``task_inputs`` or in
        ``self.bound_kwargs``.

        Raises:
            IOCheckerError: on the first missing input. The message lists the
                expected inputs and every name that is available.
        """
        for target_name, target in self.inputs.items():
            if (
                target_name not in task_inputs
                and target.default is NoDefault
                and target_name not in self.bound_kwargs
            ):
                got = set(values_from_upstream)
                got.update(task_inputs)
                got.update(self.bound_kwargs)
                got.update(
                    input_target.name
                    for input_target in self.inputs.values()
                    if input_target.default is not NoDefault
                )

                raise IOCheckerError(
                    f"Task {self.name} did not get all expected inputs: expected {list(self.inputs)}, got "
                    f"{sorted(got, key=str)}"
                )

    def run(self, *args, **kwargs):
        """Task body, provided by subclasses or by the ``as_task`` decorator."""
        raise NotImplementedError()


@overload
def as_task(
    run_func: Callable,
    cls: Type[Task] = Task,
    name: Optional[str] = None,
    with_self: bool = False,
    plugin_name: Optional[str] = None,
) -> Type[Task]:
    # Typing stub: used directly as ``@as_task``, the decorator receives the
    # run function and returns the generated task class.
    ...


@overload
def as_task(
    run_func: None = None,
    cls: Type[Task] = Task,
    name: Optional[str] = None,
    with_self: bool = False,
    plugin_name: Optional[str] = None,
) -> Callable[[Callable], Type[Task]]:
    # Typing stub: used with arguments (``@as_task(name=...)``), no run
    # function is given and a decorator is returned. ``run_func`` is kept as
    # the first parameter so positional arguments line up with the
    # implementation (whose first positional parameter is ``run_func``, not ``cls``).
    ...


def as_task(
    run_func: Callable = None,
    cls: Type[Task] = Task,
    name: Optional[str] = None,
    with_self: bool = False,
    plugin_name: Optional[str] = None,
) -> Union[Type[Task], Callable[[Callable], Type[Task]]]:
    """
    Transform the decorated function into a outflow task (class).

    Usable both as ``@as_task`` and as ``@as_task(...)``.

    Args:
        run_func: The decorated function. When None, a decorator expecting
            the function is returned instead of a task class.
        cls: class of the Task to which the decorated function will be transformed
        name (str): Name of the task. By default, the name of the decorated
            function. The generated class is named after it, converted to
            CamelCase.
        with_self (bool): If true, the run function will be called as a
            regular method, so the "self" of the task is available in the
            task code.
        plugin_name (str): Optional, overrides automatic plugin name detection

    Returns:
        A task class that runs the decorated function, or a decorator that
        builds one when ``run_func`` is None.
    """
    if run_func is not None:
        task_name = run_func.__name__ if name is None else name
        class_attributes = {
            "run": run_func,
            "with_self": with_self,
            "plugin_name": plugin_name,
        }
        return type(to_camel_case(task_name), (cls,), class_attributes)

    # called with options only: defer until the function is supplied
    def decorator(_run_func):
        return as_task(
            _run_func,
            cls,
            name=name,
            with_self=with_self,
            plugin_name=plugin_name,
        )

    return decorator
