# pylint: disable=missing-module-docstring
from typing import Callable, Optional

from phml.nodes import AST, Element, Root
from phml.utils.validate.test import Test


class Index:
    """Uses the given key or key generator and creates a mutable dict of key value pairs
    that can be easily indexed.

    Nodes that don't match the condition or don't have a valid key are not indexed.
    A key is invalid when a ``str`` key names a property the element does not have,
    or when the resolved key (property value or callable result) is ``None``.
    """

    indexed_tree: dict[str, list[Element]]
    """The indexed collection of elements"""

    # The property name or key-generating callable given to the constructor.
    key: str | Callable

    def __init__(
        self, key: str | Callable, start: AST | Root | Element, condition: Optional[Test] = None
    ):
        """
        Args:
            `key` (str | Callable): Str represents the property to use as an index. Callable
            represents a function to call on each element to generate a key. The key is used
            as-is as the index key, so it must be hashable. If the key is None (or the named
            property is missing) the element is skipped.
            `start` (AST | Root | Element): The root or node to start at while indexing
            `condition` (Test | None): The test to apply to each node. Only valid/passing nodes
            will be indexed. When None, every element is considered.
        """
        from phml.utils import test, walk  # pylint: disable=import-outside-toplevel

        self.indexed_tree = {}
        self.key = key

        for node in walk(start):
            # Only elements are indexed; `test` is only called when a condition was given.
            if isinstance(node, Element) and (condition is None or test(node, condition)):
                self.add(node)

    def _key_for(self, node: Element):
        """Return the index key for `node`, or None when the node has no valid key."""
        if isinstance(self.key, str):
            # A missing property means "no valid key" rather than an error.
            # NOTE(review): assumes `properties` is a dict-like mapping (it is subscripted by str).
            return node.properties.get(self.key)
        return self.key(node)

    def add(self, node: Element):
        """Adds element to indexed collection if not already there.

        Elements without a valid key are silently ignored.
        """
        key = self._key_for(node)
        if key is None:
            return

        bucket = self.indexed_tree.setdefault(key, [])
        if node not in bucket:
            bucket.append(node)

    def remove(self, node: Element):
        """Removes element from indexed collection if there.

        Elements without a valid key, or not currently indexed, are ignored. The key's
        list is kept (possibly empty) after removal.
        """
        key = self._key_for(node)
        if key is None:
            return

        bucket = self.indexed_tree.get(key)
        if bucket is not None and node in bucket:
            bucket.remove(node)

    def get(self, _key: str) -> Optional[list[Element]]:
        """Get a specific index from the indexed tree.

        Returns:
            The list of elements stored under `_key`, or None if that key was never indexed.
        """
        return self.indexed_tree.get(_key)

    def map(self, modifier: Callable) -> list:
        """Applies the passed modifier to each indexed element.

        Returns:
            Flat list of results from calling the modifier on every element of every index,
            in index insertion order.
        """
        return [modifier(node) for nodes in self.indexed_tree.values() for node in nodes]
