"""基础的消息处理器, 包括 DetectPrefix 与 DetectSuffix"""
import abc
import difflib
import fnmatch
import re
import weakref
from collections import defaultdict
from typing import (
    ClassVar,
    DefaultDict,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from graia.amnesia.message import Element, MessageChain, Text
from graia.broadcast.builtin.derive import Derive
from graia.broadcast.entities.decorator import Decorator
from graia.broadcast.entities.dispatcher import BaseDispatcher
from graia.broadcast.exceptions import ExecutionStop
from graia.broadcast.interfaces.decorator import DecoratorInterface
from graia.broadcast.interfaces.dispatcher import DispatcherInterface
from typing_extensions import get_args

from ._typing_util import generic_issubclass, is_subclass, is_union
from ._util import map_chain, unmap_chain


class ChainDecorator(abc.ABC, Decorator, Derive[MessageChain]):
    """Abstract base for decorators that inspect or transform a message chain.

    Subclasses implement ``__call__``; ``target`` looks up the ``message_chain``
    parameter through the dispatcher interface and passes it to that implementation.
    """

    # Flag read by the graia.broadcast decorator machinery
    # (presumably "run before dispatch" -- confirm against the Decorator docs).
    pre = True

    @abc.abstractmethod
    async def __call__(self, chain: MessageChain, interface: DispatcherInterface) -> Optional[MessageChain]:
        """Process ``chain`` and return the resulting chain.

        Args:
            chain (MessageChain): The message chain to process.
            interface (DispatcherInterface): The dispatcher interface of the current execution.
        """
        ...

    async def target(self, interface: DecoratorInterface):
        """Look up the ``message_chain`` parameter and run this decorator on it."""
        dispatcher_interface = interface.dispatcher_interface
        message_chain = await dispatcher_interface.lookup_param("message_chain", MessageChain, None)
        return await self(message_chain, dispatcher_interface)


class DetectPrefix(ChainDecorator):
    """Prefix detector."""

    def __init__(self, prefix: Union[str, Iterable[str]]) -> None:
        """Create a prefix detector.

        Args:
            prefix (Union[str, Iterable[str]]): The prefix, or several candidate prefixes, to match.
        """
        if isinstance(prefix, str):
            candidates: List[str] = [prefix]
        else:
            candidates = list(prefix)
        self.prefix: List[str] = candidates

    async def __call__(self, chain: MessageChain, _) -> Optional[MessageChain]:
        """Strip the first matching prefix from ``chain``.

        Candidates are tried in the order they were given. After the prefix is
        removed, a single leading ``" "`` is removed as well.

        Raises:
            ExecutionStop: If ``chain`` starts with none of the candidates.
        """
        matched = next((candidate for candidate in self.prefix if chain.startswith(candidate)), None)
        if matched is None:
            raise ExecutionStop
        return chain.removeprefix(matched).removeprefix(" ")


class DetectSuffix(ChainDecorator):
    """Suffix detector."""

    def __init__(self, suffix: Union[str, Iterable[str]]) -> None:
        """Create a suffix detector.

        Args:
            suffix (Union[str, Iterable[str]]): The suffix, or several candidate suffixes, to match.
        """
        if isinstance(suffix, str):
            candidates: List[str] = [suffix]
        else:
            candidates = list(suffix)
        self.suffix: List[str] = candidates

    async def __call__(self, chain: MessageChain, _) -> Optional[MessageChain]:
        """Strip the first matching suffix from ``chain``.

        Candidates are tried in the order they were given. After the suffix is
        removed, a single trailing ``" "`` is removed as well.

        Raises:
            ExecutionStop: If ``chain`` ends with none of the candidates.
        """
        matched = next((candidate for candidate in self.suffix if chain.endswith(candidate)), None)
        if matched is None:
            raise ExecutionStop
        return chain.removesuffix(matched).removesuffix(" ")


class ContainKeyword(ChainDecorator):
    """Require the message to contain a given keyword."""

    def __init__(self, keyword: str) -> None:
        """Store the keyword to look for.

        Args:
            keyword (str): The keyword.
        """
        self.keyword: str = keyword

    async def __call__(self, chain: MessageChain, _) -> Optional[MessageChain]:
        """Return ``chain`` unchanged when it contains the keyword.

        Raises:
            ExecutionStop: If the keyword is not in ``chain``.
        """
        if self.keyword in chain:
            return chain
        raise ExecutionStop


class MatchContent(ChainDecorator):
    """Match against a string or a message chain."""

    def __init__(self, content: Union[str, MessageChain]) -> None:
        """Store the content to compare against.

        Args:
            content (Union[str, MessageChain]): The content to match.
        """
        self.content: Union[str, MessageChain] = content

    async def __call__(self, chain: MessageChain, _) -> Optional[MessageChain]:
        """Return ``chain`` unchanged when it equals the stored content.

        A ``str`` content is compared with ``str(chain)``; a ``MessageChain``
        content is compared with ``chain`` itself.

        Raises:
            ExecutionStop: If the comparison for the stored content type fails.
        """
        expected = self.content
        if isinstance(expected, str):
            if str(chain) != expected:
                raise ExecutionStop
        if isinstance(expected, MessageChain):
            if chain != expected:
                raise ExecutionStop
        return chain


class MatchRegex(ChainDecorator, BaseDispatcher):
    """Regular-expression matcher."""

    def __init__(self, regex: str, flags: re.RegexFlag = re.RegexFlag(0), full: bool = True) -> None:
        """Compile the pattern and pick the matching function.

        Args:
            regex (str): The regular expression.
            flags (re.RegexFlag): Flags passed to ``re.compile``.
            full (bool): Whether a full match is required. Defaults to True.
        """
        self.regex: str = regex
        self.flags: re.RegexFlag = flags
        self.pattern = re.compile(self.regex, self.flags)
        if full:
            self.match_func = self.pattern.fullmatch
        else:
            self.match_func = self.pattern.match

    async def __call__(self, chain: MessageChain, _) -> Optional[MessageChain]:
        """Return ``chain`` unchanged when ``str(chain)`` matches the pattern.

        Raises:
            ExecutionStop: If the pattern does not match.
        """
        if self.match_func(str(chain)):
            return chain
        raise ExecutionStop

    async def beforeExecution(self, interface: DispatcherInterface):
        """Match the mapped form of the message chain and stash the result.

        The match object and the mapping returned by ``map_chain`` are stored in
        ``interface.local_storage`` for later use.

        Raises:
            ExecutionStop: If the pattern does not match the mapped string.
        """
        message_chain: MessageChain = await interface.lookup_param("message_chain", MessageChain, None)
        mapped_text, element_map = map_chain(message_chain)
        match_result = self.match_func(mapped_text)
        if not match_result:
            raise ExecutionStop
        interface.local_storage["__parser_regex_match_obj__"] = match_result
        interface.local_storage["__parser_regex_match_map__"] = element_map

    async def catch(self, interface: DispatcherInterface):
        """Supply the stored match object to parameters annotated as ``re.Match``."""
        if interface.annotation is not re.Match:
            return None
        return interface.local_storage["__parser_regex_match_obj__"]


class RegexGroup(Decorator):
    """Marker for a regular-expression group.

    Use it as ``Annotated[MessageChain, RegexGroup("xxx")]``,
    or as a Decorator.
    """

    def __init__(self, target: Union[int, str]) -> None:
        """Store which group to extract.

        Args:
            target (Union[int, str]): Name or index of the target group. An integer
                indexes into ``Match.groups()``, so ``0`` is the first capturing group.
        """
        self.assign_target = target

    async def __call__(self, _, interface: DispatcherInterface):
        """Return the selected group converted back through ``unmap_chain``.

        Reads the match object and mapping stored in ``interface.local_storage``
        under ``__parser_regex_match_obj__`` and ``__parser_regex_match_map__``.

        Returns:
            The result of ``unmap_chain`` for the group's text, or ``None`` when the
            group does not exist or did not take part in the match.
        """
        _res: re.Match = interface.local_storage["__parser_regex_match_obj__"]
        # Groups that did not participate in the match are reported as None.
        match_group: Tuple[Optional[str], ...] = _res.groups()
        match_group_dict: Dict[str, Optional[str]] = _res.groupdict()
        origin: Optional[str] = None
        if isinstance(self.assign_target, str) and self.assign_target in match_group_dict:
            origin = match_group_dict[self.assign_target]
        elif isinstance(self.assign_target, int) and self.assign_target < len(match_group):
            origin = match_group[self.assign_target]

        if origin is None:
            return None
        return unmap_chain(origin, interface.local_storage["__parser_regex_match_map__"])

    async def target(self, interface: DecoratorInterface):
        """Resolve the group for the current execution."""
        # Await here so the caller gets the value rather than a pending coroutine.
        return await self("", interface.dispatcher_interface)


class MatchTemplate(ChainDecorator):
    """Template matching."""

    def __init__(self, template: List[Union[Type[Element], Element, str]]) -> None:
        """Normalise ``template`` into one matcher per expected element.

        Args:
            template (List[Union[Type[Element], Element, str]]): The matching template. Each item may be
                an ``Element`` class or a ``Union`` of such classes, an ``Element`` instance,
                a ``Text`` instance (matched literally), or a ``str`` glob pattern in ``fnmatch`` syntax.
        """
        # Entries are: a tuple of accepted classes, an Element to compare by equality,
        # or a regular-expression source string for a text element.
        self.template: List[Union[Tuple[Type[Element], ...], Element, str]] = []
        for pattern in template:
            if is_subclass(pattern, Text):
                # A text class carries no content to compare, so accept any text.
                pattern = "*"
            if isinstance(pattern, type):
                self.template.append((pattern,))
            elif is_union(pattern):  # Union
                assert not any(is_subclass(t, Text) for t in get_args(pattern)), "Leaving Text here leads to ambiguity"
                self.template.append(get_args(pattern))
            elif isinstance(pattern, Element) and not isinstance(pattern, Text):
                self.template.append(pattern)
            else:
                pattern = (
                    re.escape(pattern.text)
                    if isinstance(pattern, Text)
                    # Drop the two-character end anchor so adjacent text patterns can be concatenated.
                    else fnmatch.translate(pattern)[:-2]
                )
                if self.template and isinstance(self.template[-1], str):
                    # Consecutive text patterns describe a single text element.
                    self.template[-1] += pattern
                else:
                    self.template.append(pattern)

    def match(self, chain: MessageChain):
        """Check whether ``chain`` fits the template.

        Returns:
            bool: True when ``chain`` has as many elements as the normalised template
            and every element satisfies its corresponding entry.
        """
        if len(self.template) != len(chain):
            return False
        for element, template in zip(chain, self.template):
            if isinstance(template, tuple) and not isinstance(element, template):
                return False
            elif isinstance(template, Element) and element != template:
                return False
            elif isinstance(template, str):
                # The end anchor was stripped in __init__, so require a full match here;
                # otherwise "hello" would also accept "hello world".
                if not isinstance(element, Text) or not re.fullmatch(template, element.text):
                    return False
        return True

    async def __call__(self, chain: MessageChain, _) -> Optional[MessageChain]:
        """Return ``chain`` unchanged when it fits the template.

        Raises:
            ExecutionStop: If ``chain`` does not match.
        """
        if not self.match(chain):
            raise ExecutionStop
        return chain


class FuzzyMatch(ChainDecorator):
    """Fuzzy matching.

    Warning:
        FuzzyDispatcher is the recommended way to do fuzzy matching, because it
        limits how many handlers can match within one context.
    """

    def __init__(self, template: str, min_rate: float = 0.6) -> None:
        """Store the template and threshold.

        Args:
            template (str): The template string.
            min_rate (float): The minimum similarity ratio required.
        """
        self.template: str = template
        self.min_rate: float = min_rate

    def match(self, chain: MessageChain):
        """Check whether the text of ``chain`` is similar enough to the template.

        ``Text`` elements contribute their ``text``; every other element
        contributes ``str(element)``.
        """
        text = "".join(element.text if isinstance(element, Text) else str(element) for element in chain)
        matcher = difflib.SequenceMatcher(a=text, b=self.template)
        threshold = self.min_rate
        # Both estimates are upper bounds on ratio(), so failing either one
        # rules the match out before the expensive computation.
        for upper_bound in (matcher.real_quick_ratio, matcher.quick_ratio):
            if upper_bound() < threshold:
                return False
        return matcher.ratio() >= threshold

    async def __call__(self, chain: MessageChain, _) -> Optional[MessageChain]:
        """Return ``chain`` unchanged when it is similar enough to the template.

        Raises:
            ExecutionStop: If the similarity is below ``min_rate``.
        """
        if self.match(chain):
            return chain
        raise ExecutionStop


class FuzzyDispatcher(BaseDispatcher):
    """Fuzzy-matching dispatcher that lets only the best template of a scope proceed.

    All instances register their template in a class-level, per-scope list. For each
    event the best-matching template of every scope is computed once and cached;
    an instance proceeds only if its own template won its scope with a ratio of at
    least ``min_rate``.
    """

    # scope name -> every template registered under that scope
    scope_map: ClassVar[DefaultDict[str, List[str]]] = defaultdict(list)
    # id(event) -> {scope name: (winning template, its ratio)}
    event_ref: ClassVar["Dict[int, Dict[str, Tuple[str, float]]]"] = {}

    def __init__(self, template: str, min_rate: float = 0.6, scope: str = "") -> None:
        """Store the settings and register the template under its scope.

        Args:
            template (str): The template string.
            min_rate (float): The minimum similarity ratio required.
            scope (str): The scope this template competes in.
        """
        self.template: str = template
        self.min_rate: float = min_rate
        self.scope: str = scope
        self.scope_map[scope].append(template)

    async def beforeExecution(self, interface: DispatcherInterface):
        """Compute (or reuse) the per-scope winners for the event and gate execution.

        Raises:
            ExecutionStop: If this instance's template is not the winner of its scope,
                or the winning ratio is below ``min_rate``.
        """
        event = interface.event
        event_key = id(event)
        if event_key not in self.event_ref:
            chain: MessageChain = await interface.lookup_param("message_chain", MessageChain, None)
            text = "".join(element.text if isinstance(element, Text) else str(element) for element in chain)
            matcher = difflib.SequenceMatcher()
            # The message text is fixed per event; only the first sequence varies per template.
            matcher.set_seq2(text)
            rate_calc = self.event_ref[event_key] = {}
            # Drop the cache entry once the event is collected. Passing a default keeps
            # the finalizer from raising if the entry is already gone, and using the
            # bound dict method avoids holding a reference to this instance.
            weakref.finalize(event, self.event_ref.pop, event_key, None)
            for scope, templates in self.scope_map.items():
                max_match: float = 0.0
                for template in templates:
                    matcher.set_seq1(template)
                    # Cheap upper bounds first: skip templates that cannot beat the current best.
                    if matcher.real_quick_ratio() < max_match:
                        continue
                    if matcher.quick_ratio() < max_match:
                        continue
                    ratio = matcher.ratio()
                    if ratio < max_match:
                        continue
                    rate_calc[scope] = (template, ratio)
                    max_match = ratio
        win_template, win_rate = self.event_ref[event_key].get(self.scope, (self.template, 0.0))
        if win_template != self.template or win_rate < self.min_rate:
            raise ExecutionStop

    async def catch(self, i: DispatcherInterface) -> Optional[float]:
        """Supply the winning ratio to ``float`` parameters whose name contains ``rate``."""
        if not (generic_issubclass(float, i.annotation) and "rate" in i.name):
            return None
        # Only consult the per-event cache for parameters that actually want the rate.
        _, rate = self.event_ref[id(i.event)].get(self.scope, (self.template, 0.0))
        return rate


# Alternative names: StartsWith is DetectPrefix and EndsWith is DetectSuffix.
StartsWith = DetectPrefix
EndsWith = DetectSuffix
