from .buffer import RenderBuffer
from .context import (
    HeaderContext,
    ClassContext,
    ClassTemplateData,
    FunctionContext,
    TrampolineData,
)
from .mangle import trampoline_signature

from . import render_pybind11 as rpybind11


def precomma(v: str) -> str:
    """Return ``v`` preceded by ``", "``; an empty/falsy ``v`` yields ``""``."""
    if not v:
        return ""
    return f", {v}"


def postcomma(v: str) -> str:
    """Return ``v`` followed by ``", "``; an empty/falsy ``v`` yields ``""``."""
    if not v:
        return ""
    return f"{v}, "


def using_signature(cls: ClassContext, fn: FunctionContext) -> str:
    """
    Build the identifier ``<class identifier>_<method name>``.

    The trampoline renderer appends it to ``RPYBLD_UDISABLE_`` to guard
    against emitting the same ``using`` declaration twice.
    """
    return "{}_{}".format(cls.full_cpp_name_identifier, fn.cpp_name)


def render_cls_rpy_include_hpp(ctx: HeaderContext, cls: ClassContext) -> str:
    """
    Render the rpy-include header for a single class.

    The output starts with the autogenerated preamble, then includes: the
    header's ``extra_includes_first``, the header itself (``rel_fname``), and
    ``extra_includes``. After that come whichever of these apply:

    - the trampoline base class (when ``cls.trampoline`` is set)
    - template constructors/method fillers (when ``cls.template`` is set)

    :returns: the complete header text
    """

    out = RenderBuffer()

    def emit_includes(headers) -> None:
        # blank separator line followed by one #include per entry; emits
        # nothing when the list is empty
        if not headers:
            return
        out.writeln()
        for header in headers:
            out.writeln(f"#include <{header}>")

    out.writeln(
        "// This file is autogenerated. DO NOT EDIT\n"
        "\n"
        "#pragma once\n"
        "#include <robotpy_build.h>"
    )

    emit_includes(ctx.extra_includes_first)
    out.writeln(f"\n#include <{ctx.rel_fname}>")
    emit_includes(ctx.extra_includes)

    trampoline = cls.trampoline
    if trampoline is not None:
        _render_cls_trampoline(out, ctx, cls, trampoline)

    template = cls.template
    if template is not None:
        _render_cls_template_impl(out, ctx, cls, template)

    return out.getvalue()


def _render_cls_trampoline(
    r: RenderBuffer, hctx: HeaderContext, cls: ClassContext, trampoline: TrampolineData
):
    """
    Generate the trampoline class for ``cls`` into ``r``.

    Trampolines serve two purposes:

    * Allow python programs to override virtual functions
    * Allow python programs access to protected members

    The emitted ``PyTrampoline_<name>`` is a template over its own base
    (``PyTrampolineBase``). It can therefore be used both to wrap ``cls``
    itself and as a building block by trampolines of child classes. Those
    child trampolines include ``rpygen/<base>.hpp`` and compose
    ``PyTrampoline_<base>`` through ``PyTrampolineBase_<name>``.

    Emitted, in order:

    - ``#define RPYGEN_DISABLE_<sig>`` for each of
      ``trampoline.methods_to_disable``
    - ``#include <rpygen/<base>.hpp>`` for each base class
    - the class namespace (if any) and the header's using declarations
    - ``PyTrampolineCfg_<name>``: configuration struct chained through the
      bases' configurations
    - ``PyTrampolineBase_<name>`` (only when there are bases)
    - ``PyTrampoline_<name>`` with type aliases, constants, protected
      constructors, virtual overrides, protected members and inline code

    Individual virtual overrides are guarded by
    ``#ifndef RPYGEN_DISABLE_<trampoline_signature(fn)>``.
    """

    # template parameters/arguments of the class itself; empty for non-templates
    if cls.template:
        template_argument_list = cls.template.argument_list
        template_parameter_list = cls.template.parameter_list
    else:
        template_argument_list = ""
        template_parameter_list = ""

    # delete specified methods -- emitted before the base headers are included
    # below, presumably so the defines also apply to the bases' overrides
    if trampoline.methods_to_disable:
        r.writeln()
        for fn in trampoline.methods_to_disable:
            r.writeln(f"#define RPYGEN_DISABLE_{ trampoline_signature(fn) }")

    # include override files for each base -- TODO: exclude some bases?
    if cls.bases:
        r.writeln()
        for base in cls.bases:
            r.writeln(f"#include <rpygen/{ base.full_cpp_name_identifier }.hpp>")

    if cls.namespace:
        # strip any leading/trailing ':' so "::ns" becomes a valid namespace name
        r.writeln(f"\nnamespace {cls.namespace.strip(':')} {{")

    if hctx.using_declarations:
        r.writeln()
        for decl in hctx.using_declarations:
            r.writeln(f"using {decl.format()};")

    #
    # Each trampoline has a configuration struct.
    #
    # - Stores the base class that the trampoline is wrapping
    # - Provides a mechanism to detect which base class to use when calling an
    #   overloaded virtual function (each class defines the overloads they have,
    #   and so if it's not defined in this config, then it falls back to the
    #   parent configuration)
    #

    r.writeln(
        f"\ntemplate <{postcomma(template_parameter_list)}typename CfgBase = rpygen::EmptyTrampolineCfg>"
    )

    if cls.bases:
        # inherit from Base1Cfg<Base2Cfg<...<CfgBase>>>
        r.writeln(f"struct PyTrampolineCfg_{cls.cpp_name} :")

        with r.indent():
            for base in cls.bases:
                r.writeln(
                    f"{base.namespace_}PyTrampolineCfg_{base.cls_name}<{postcomma(base.template_params)}"
                )

            r.writeln("CfgBase")

            # one closing '>' per base opened above
            for _ in cls.bases:
                r.writeln(">")
    else:
        r.writeln(f"struct PyTrampolineCfg_{cls.cpp_name} : CfgBase")

    r.writeln("{")

    with r.indent():
        r.writeln(f"using Base = {cls.full_cpp_name};\n")

        # specify base class to use for each virtual function
        for fn in trampoline.virtual_methods:
            r.writeln(
                f"using override_base_{ trampoline_signature(fn) } = { cls.full_cpp_name };"
            )

    r.writeln("};")

    if cls.bases:
        # To avoid multiple inheritance here, we define a single base with bases that
        # are all template bases..
        #
        # PyTrampolineBase is another trampoline or our base class
        r.writeln()

        r.writeln(
            f"template <typename PyTrampolineBase{precomma(template_parameter_list)}, typename PyTrampolineCfg>"
        )
        r.writeln(f"using PyTrampolineBase_{cls.cpp_name} =")

        # open one PyTrampoline_<base>< per base, nesting deeper each time ...
        for base in cls.bases:
            r.rel_indent(2)
            r.writeln(f"{base.namespace_}PyTrampoline_{base.cls_name}<")

        # ... then close them innermost-first, so the arguments pair up with
        # the matching base
        with r.indent():
            r.writeln("PyTrampolineBase")

            for base in reversed(cls.bases):
                if base.template_params:
                    r.writeln(f", {base.template_params}")
                r.writeln(", PyTrampolineCfg>")
                r.rel_indent(-2)

        r.write_trim(
            f"""
            ;

            template <typename PyTrampolineBase{ precomma(template_parameter_list) }, typename PyTrampolineCfg>
            struct PyTrampoline_{ cls.cpp_name } : PyTrampolineBase_{ cls.cpp_name }<PyTrampolineBase{ precomma(template_argument_list) }, PyTrampolineCfg> {{
              using PyTrampolineBase_{ cls.cpp_name }<PyTrampolineBase{ precomma(template_argument_list) }, PyTrampolineCfg>::PyTrampolineBase_{ cls.cpp_name };
        """
        )

    else:
        r.writeln()
        r.write_trim(
            f"""
            template <typename PyTrampolineBase{ precomma(template_parameter_list) }, typename PyTrampolineCfg>
            struct PyTrampoline_{ cls.cpp_name } : PyTrampolineBase {{
              using PyTrampolineBase::PyTrampolineBase;
        """
        )

    with r.indent():
        for ccls in cls.child_classes:
            if not ccls.template:
                r.writeln(
                    f"using {ccls.cpp_name} [[maybe_unused]] = typename {ccls.full_cpp_name};"
                )

        for enum in cls.enums:
            if enum.cpp_name:
                r.writeln(
                    f"using {enum.cpp_name} [[maybe_unused]] = typename {enum.full_cpp_name};"
                )

        for typealias in cls.user_typealias:
            r.writeln(f"{typealias};")

        for typealias in cls.auto_typealias:
            r.writeln(f"{typealias};")

        if cls.constants:
            r.writeln()
            for name, constant in cls.constants:
                r.writeln(f"static constexpr auto {name} = {constant};")

        #
        # protected constructors -- only used by the direct child
        #

        for fn in trampoline.protected_constructors:
            r.writeln(
                f"\n#ifdef RPYGEN_ENABLE_{cls.full_cpp_name_identifier}_PROTECTED_CONSTRUCTORS"
            )
            with r.indent():
                all_decls = ", ".join(p.decl for p in fn.all_params)
                all_names = ", ".join(p.arg_name for p in fn.all_params)
                r.writeln(f"PyTrampoline_{cls.cpp_name}({all_decls}) :")

                if cls.bases:
                    # The mem-initializer must name the base exactly as the
                    # PyTrampoline_<name> declaration above does, i.e. with the
                    # class's own template argument list.
                    r.writeln(
                        f"  PyTrampolineBase_{cls.cpp_name}<PyTrampolineBase{precomma(template_argument_list)}, PyTrampolineCfg>({all_names})"
                    )
                else:
                    r.writeln(f"  PyTrampolineBase({all_names})")

                r.writeln("{}")
            r.writeln("#endif")

        #
        # virtual methods
        #

        for fn in trampoline.virtual_methods:
            _render_cls_trampoline_virtual_method(r, cls, fn)

        #
        # non-virtual protected methods/attributes
        #

        for fn in trampoline.non_virtual_protected_methods:
            # NOTE(review): this guard uses RPYBLD_DISABLE_, whereas
            # methods_to_disable above emits RPYGEN_DISABLE_ defines, so those
            # defines do not affect it -- confirm this is intentional.
            r.writeln(f"\n#ifndef RPYBLD_DISABLE_{ trampoline_signature(fn) }")

            # hack to ensure we don't do 'using' twice in the same class, while
            # also ensuring that the overrides can be selectively disabled by
            # child trampoline functions
            with r.indent():
                r.writeln(f"#ifndef RPYBLD_UDISABLE_{ using_signature(cls, fn) }")
                with r.indent():
                    r.write_trim(
                        f"""
                        using { cls.full_cpp_name }::{ fn.cpp_name };
                        #define RPYBLD_UDISABLE_{ using_signature(cls, fn) }
                    """
                    )
                r.writeln("#endif")
            r.writeln("#endif")

        if cls.protected_properties:
            r.writeln()
            for prop in cls.protected_properties:
                r.writeln(f"using {cls.full_cpp_name}::{prop.cpp_name};")

        if trampoline.inline_code:
            r.writeln()
            r.write_trim(trampoline.inline_code)

    r.writeln("};\n\n")

    if cls.namespace:
        r.writeln(f"}}; // namespace {cls.namespace}")


def _render_cls_trampoline_virtual_method(
    r: RenderBuffer, cls: ClassContext, fn: FunctionContext
):
    """
    Emit the trampoline override of a single virtual method ``fn``.

    The override is wrapped in
    ``#ifndef RPYGEN_DISABLE_<trampoline_signature(fn)>`` so it can be
    suppressed. Its body is chosen in this order:

    1. ``fn.trampoline_cpp_code`` verbatim, when provided
    2. ``throw std::runtime_error("not implemented");`` when ``fn.ignore_pure``
    3. otherwise one of four override-macro forms, picked by
       pure/non-pure and by whether ``fn.virtual_xform`` is set. The
       non-pure forms end with ``return CxxCallBase::<name>(...)``, which
       calls the C++ implementation named by
       ``PyTrampolineCfg::override_base_<sig>``.
    """
    r.writeln(f"\n#ifndef RPYGEN_DISABLE_{ trampoline_signature(fn) }")
    with r.indent():

        all_decls = ", ".join(p.decl for p in fn.all_params)
        const = " const" if fn.const else ""
        decl = f"{fn.cpp_return_type} {fn.cpp_name}({all_decls}){const}{fn.ref_qualifiers} override {{"
        r.writeln(decl)

        with r.indent():
            if fn.trampoline_cpp_code:
                # user-supplied body replaces the generated one entirely
                r.write_trim(fn.trampoline_cpp_code)
            elif fn.ignore_pure:
                r.writeln('throw std::runtime_error("not implemented");')

            else:
                # TODO: probably will break for things like out parameters, etc
                if fn.virtual_xform:
                    # the *_CUSTOM_* macros below receive fn.cpp_name; this
                    # local is presumably what they invoke -- confirm against
                    # the macro definitions
                    r.writeln(f"auto custom_fn = {fn.virtual_xform};")

                # We define a "LookupBase" and "CallBase" here because to find the python
                # override we need to use the actual class currently being overridden, but
                # to make the actual call we might need to use a base class.
                #
                # .. lots of duplication here, but it's worse without it

                r.writeln("using LookupBase = typename PyTrampolineCfg::Base;")

                # arguments forwarded to the Python override vs. to the C++
                # fallback call (they may differ per parameter)
                all_names = ", ".join(p.arg_name for p in fn.all_params)
                all_vnames = ", ".join(p.virtual_call_name for p in fn.all_params)

                if fn.is_pure_virtual and fn.virtual_xform:
                    r.write_trim(
                        f"""
                        RPYBUILD_OVERRIDE_PURE_CUSTOM_NAME({cls.cpp_name}, PYBIND11_TYPE({fn.cpp_return_type}), LookupBase,
                          "{fn.py_name}", {fn.cpp_name}, {all_names});
                        """
                    )
                elif fn.is_pure_virtual:
                    r.write_trim(
                        f"""
                        RPYBUILD_OVERRIDE_PURE_NAME({cls.cpp_name}, PYBIND11_TYPE({fn.cpp_return_type}), LookupBase,
                          "{fn.py_name}", {fn.cpp_name}, {all_names});
                        """
                    )
                elif fn.virtual_xform:
                    r.write_trim(
                        f"""
                        using CxxCallBase = typename PyTrampolineCfg::override_base_{trampoline_signature(fn)};
                        RPYBUILD_OVERRIDE_CUSTOM_IMPL(PYBIND11_TYPE({fn.cpp_return_type}), LookupBase,
                          "{fn.py_name}", {fn.cpp_name}, {all_names});
                        return CxxCallBase::{fn.cpp_name}({all_vnames});
                        """
                    )
                else:
                    # PYBIND11_OVERRIDE_IMPL returns from the function when a
                    # Python override exists; otherwise the C++ call runs
                    r.write_trim(
                        f"""
                        using CxxCallBase = typename PyTrampolineCfg::override_base_{trampoline_signature(fn)};
                        PYBIND11_OVERRIDE_IMPL(PYBIND11_TYPE({fn.cpp_return_type}), LookupBase,
                          "{fn.py_name}", {all_names});
                        return CxxCallBase::{fn.cpp_name}({all_vnames});
                        """
                    )

        r.writeln("}")
    r.writeln("#endif")


def _render_cls_template_impl(
    r: RenderBuffer, hctx: HeaderContext, cls: ClassContext, template: ClassTemplateData
):
    """
    Emit the ``rpygen::bind_<identifier>`` helper struct for a class template.

    The struct is itself templated on ``template.parameter_list``. It holds
    the module reference and Python class name. Its constructor runs the
    class initialization and enum definitions. ``finish()`` emits the
    remaining class definitions, then optionally replaces (``set_doc``) or
    appends to (``add_doc``) the class docstring, then appends
    ``template.inline_code``.
    """
    if hctx.type_caster_includes:
        r.writeln()
        for inc in hctx.type_caster_includes:
            r.writeln(f"#include <{inc}>")

    r.writeln("\nnamespace rpygen {")

    if cls.namespace:
        r.writeln(f"\nusing namespace {cls.namespace};")

    if hctx.using_declarations:
        r.writeln()
        for decl in hctx.using_declarations:
            r.writeln(f"using {decl.format()};")

    r.writeln(f"\ntemplate <{template.parameter_list}>")
    r.writeln(f"struct bind_{cls.full_cpp_name_identifier} {{")

    with r.indent():
        # struct-scope aliases, constants and the class object declaration
        rpybind11.cls_user_using(r, cls)
        rpybind11.cls_auto_using(r, cls)
        rpybind11.cls_consts(r, cls)
        rpybind11.cls_decl(r, cls)

        r.writeln("\npy::module &m;\nstd::string clsName;")

        r.writeln(
            f"bind_{cls.full_cpp_name_identifier}(py::module &m, const char * clsName) :"
        )

        with r.indent():

            # TODO: embedded structs will fail here
            # presumably emits the class object's member-initializer (ending
            # in a comma, since m/clsName follow) -- confirm in render_pybind11
            rpybind11.cls_init(r, cls, "clsName")
            r.writeln("m(m),")
            r.writeln("clsName(clsName) {")

            rpybind11.cls_def_enum(r, cls, cls.var_name)

        # close the constructor body and open finish()
        r.write_trim(
            """
            }
            
            void finish(const char * set_doc = NULL, const char * add_doc = NULL) {
            """
        )
        with r.indent():
            rpybind11.cls_def(r, cls, cls.var_name)

            # NOTE(review): if the class has no docstring, doc() may be None and
            # py::cast<std::string> in the add_doc branch would throw -- confirm
            # that cls_init always supplies one.
            r.write_trim(
                f"""
                if (set_doc) {{
                  {cls.var_name}.doc() = set_doc;
                }}
                if (add_doc) {{
                  {cls.var_name}.doc() = py::cast<std::string>({cls.var_name}.doc()) + add_doc;
                }}
            """
            )

            if template.inline_code:
                r.writeln()
                r.write_trim(template.inline_code)

        r.writeln("}")

    r.write_trim(
        f"""
        }}; // struct bind_{cls.full_cpp_name_identifier}

        }}; // namespace rpygen
        """
    )
