# Copyright 2015 Sean Vig
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import re
import textwrap
from typing import Iterator, List, Mapping, Match, Optional

# Header line that every Printer places at the top of its output
HEAD_MSG = "# This file has been autogenerated by the pywayland scanner"

# Match the bullet ("* " or "- ") that starts an unordered list item at the
# beginning of a paragraph
RE_DOC_LIST = re.compile(r"^(?P<list_head>\* |- )")
# Template for matching an interface name in docstring text, optionally
# followed by a method reference (".name" or ".name()").  The "{}" is filled
# with a "|"-separated alternation of known interface names before compiling.
# The match must not be preceded by a word character, and must not be
# followed by a word character, "*" or "'".
RE_DOC = r"(?<![_\w])(?P<interface>{})(?P<func>\.[a-z][_a-z]*(?:\(\))?)?(?=[^\w\*']|$)"
# The base import path for sphinx directives in docstrings
BASE_PATH = "pywayland.protocol"
# Total width of wrapped docstring lines; the current indentation is
# subtracted from this before wrapping, so indent + text fits in this width
DOC_WIDTH = 79
# Number of spaces per indentation level
TAB_STOP = 4


class Printer:
    """Collect the lines of a generated protocol module.

    Lines are added by calling the instance, are indented according to the
    current nesting level (see :meth:`indented`), and are written out with
    :meth:`write`.  The docstring helpers additionally turn references to
    known interfaces into Sphinx cross-references.
    """

    def __init__(
        self,
        protocol: str,
        interface_name: Optional[str] = None,
        interface_imports: Optional[Mapping[str, str]] = None,
    ) -> None:
        """Base level printer object

        Allows for storing of lines to be output from the definition of a
        protocol.  Lines are added by directly calling the printer object.

        :param protocol:
            The name of the protocol that is currently being generated, used
            for determining import resolution.
        :param interface_name:
            The name of the interface that is being generated, used for
            determining import resolution.
        :param interface_imports:
            The map from the interface name to the protocol it is defined in,
            for resolving imports.  When ``None`` or empty, docstring text is
            emitted without any cross-reference substitution.
        """
        self._level = 0
        self._lines = [HEAD_MSG, ""]
        self._protocol_name = protocol
        self._interface_name = interface_name
        self._interface_imports = interface_imports

        self._re_doc = None
        if interface_imports:
            # Empty names are dropped: an empty alternative in the pattern
            # would match empty strings throughout the text.  Longer names
            # are tried first so a name is never shadowed by a shorter one.
            interface_names = sorted(
                (name for name in interface_imports if name),
                key=len,
                reverse=True,
            )
            if interface_names:
                alternation = "|".join(re.escape(name) for name in interface_names)
                self._re_doc = re.compile(RE_DOC.format(alternation))

    def __call__(self, new_line: Optional[str] = None) -> None:
        """Add the new line to the printer

        Non-empty lines are prefixed with the current indentation; empty or
        ``None`` lines are stored as an empty string (no trailing spaces).

        :param new_line:
            The new line to add to the file, should be appropriately wrapped.
        """
        if new_line:
            self._lines.append((" " * TAB_STOP * self._level) + new_line)
        else:
            self._lines.append("")

    def docstring(self, docstring: str) -> None:
        """Add lines as docstrings

        In addition to the operations performed by :meth:`Printer.doc()`, will
        wrap text passed to it to the correct width.  Paragraphs are separated
        by blank lines; a paragraph starting with ``* `` or ``- `` is treated
        as an unordered list and each item is wrapped with a hanging indent.

        :param docstring:
            The docstring line to add to the printer.
        """
        docstring = self._parse_doc_line(docstring)
        # the indentation is added by __call__, so leave room for it here
        base_width = DOC_WIDTH - TAB_STOP * self._level

        # create a list of properly indented paragraphs
        paragraphs: List[str] = []
        for paragraph in docstring.split("\n\n"):
            # try to detect and properly output lists
            doc_list_match = RE_DOC_LIST.search(paragraph.lstrip())
            if doc_list_match is None:
                # not a list, just create the paragraph
                paragraphs.append(
                    textwrap.fill(paragraph, width=base_width, break_long_words=False)
                )
                continue

            # the list items are not always separated by paragraph breaks,
            # so parse each line and see if they denote list items
            start_list = doc_list_match.group("list_head")
            lines = paragraph.split("\n")
            current_list_item = [lines[0].strip()]
            list_items: List[str] = []
            for line in lines[1:]:
                stripped = line.strip()
                if stripped.startswith(start_list):
                    # store the current list item and start a new one
                    list_items.append(" ".join(current_list_item))
                    current_list_item = [stripped]
                else:
                    current_list_item.append(stripped)
            # store the last list item
            list_items.append(" ".join(current_list_item))

            # add each list item as a new paragraph, continuation lines are
            # aligned under the text following the bullet
            paragraphs.extend(
                textwrap.fill(
                    list_item,
                    width=base_width,
                    subsequent_indent=" " * len(start_list),
                    break_long_words=False,
                )
                for list_item in list_items
            )

        wrapped = "\n\n".join(paragraphs)
        for line in wrapped.split("\n"):
            self(line)

    def doc(self, new_line: str) -> None:
        """Add lines as docstrings

        Performs additional massaging of strings, replacing references to other
        protocols and protocol methods with the appropriate Sphinx
        cross-reference.

        :param new_line:
            The new line to add to the printer output.
        """
        new_line = self._parse_doc_line(new_line)
        self(new_line)

    def _parse_doc_line(self, line: str) -> str:
        """Parse the given docstring input

        Perform all necessary class and function replacements that match the
        given set of interfaces that are known.  When no interfaces are known,
        the line is returned unchanged.
        """
        if self._re_doc is None:
            # nothing to cross-reference against
            return line
        return self._re_doc.sub(self._doc_replace, line)

    def _doc_replace(self, match: Match) -> str:
        """Perform interface and function name replacement on the given match"""
        if match.group("func") is None:
            return self._doc_class_replace(match)
        return self._doc_funcs_replace(match)

    @staticmethod
    def _class_name(interface_name: str) -> str:
        """Convert a snake_case interface name to its CamelCase class name"""
        return "".join(part.capitalize() for part in interface_name.split("_"))

    def _doc_class_replace(self, match: Match) -> str:
        """Build the sphinx doc class import

        :param match:
            The regex match for the given interface class.
        :returns:
            The string corresponding to the sphinx formatted class.
        """
        interface_name = match.group("interface")
        interface_class = self._class_name(interface_name)

        if interface_name == self._interface_name:
            return ":class:`{}`".format(interface_class)

        if (
            self._interface_imports is not None
            and interface_name in self._interface_imports
        ):
            protocol_path = self._interface_imports[interface_name]
            return ":class:`~{base_path}.{iface}.{class_name}`".format(
                class_name=interface_class, base_path=BASE_PATH, iface=protocol_path,
            )

        return "`{}`".format(interface_class)

    def _doc_funcs_replace(self, match: Match) -> str:
        """Build the sphinx doc function definition

        :param match:
            The regex match for the given interface and function.
        :returns:
            The string corresponding to the sphinx formatted function.
        """
        interface_name = match.group("interface")
        function_name = match.group("func")
        # The pattern may capture a trailing "()"; drop it because the
        # templates below add their own parentheses (otherwise the output
        # would read "name()()" and the Sphinx target would not resolve).
        if function_name.endswith("()"):
            function_name = function_name[:-2]
        interface_class = self._class_name(interface_name)

        if interface_name == self._interface_name:
            return ":func:`{}{}()`".format(interface_class, function_name)

        if (
            self._interface_imports is not None
            and interface_name in self._interface_imports
        ):
            protocol_path = self._interface_imports[interface_name]
            return ":func:`{class_name}{func}() <{base_path}.{iface}.{class_name}{func}>`".format(
                class_name=interface_class,
                func=function_name,
                base_path=BASE_PATH,
                iface=protocol_path,
            )

        return "`{}{}()`".format(interface_class, function_name)

    @contextlib.contextmanager
    def indented(self) -> Iterator[None]:
        """Indent in a level in the context manager block

        The level is restored even if the block raises, so a failure while
        generating one section does not corrupt the indentation of the rest.
        """
        self._level += 1
        try:
            yield
        finally:
            self._level -= 1

    def write(self, f) -> None:
        """Write the lines added to the printer out to the given file

        :param f:
            A file-like object opened in binary mode; each line is encoded
            with the default (UTF-8) codec and terminated with ``b"\\n"``.
        """
        for line in self._lines:
            f.write(line.encode())
            f.write(b"\n")
