from argparse import ArgumentParser
from pathlib import Path
import re
import sys
from tempfile import NamedTemporaryFile
import textwrap
import typing

from mako.pygen import PythonPrinter

sys.path.append(str(Path(__file__).parent.parent))

if True:  # avoid flake/zimports messing with the order
    from alembic.operations.base import Operations
    from alembic.runtime.environment import EnvironmentContext
    from alembic.script.write_hooks import console_scripts
    from alembic.util.compat import inspect_formatargspec
    from alembic.util.compat import inspect_getfullargspec
    from alembic.operations import ops
    import sqlalchemy as sa

# Public names of each proxied class that must not get a stub, keyed by the
# stub file being generated ("op" / "context").  main() passes the selected
# set down as ``ignore_items`` and generate_pyi_for_proxy() skips every name
# found in it.
IGNORE_ITEMS = {
    "op": {"context", "create_module_class_proxy"},
    "context": {
        "create_module_class_proxy",
        "get_impl",
        "requires_connection",
    },
}
# Module prefixes that are removed from the rendered annotations, so that the
# stubs refer to these names without their fully qualified module path.
TRIM_MODULE = [
    "alembic.runtime.migration.",
    "alembic.operations.ops.",
    "sqlalchemy.engine.base.",
    "sqlalchemy.sql.schema.",
    "sqlalchemy.sql.selectable.",
    "sqlalchemy.sql.elements.",
    "sqlalchemy.sql.type_api.",
    "sqlalchemy.sql.functions.",
    "sqlalchemy.sql.dml.",
]


def generate_pyi_for_proxy(
    cls: type,
    progname: str,
    source_path: Path,
    destination_path: Path,
    ignore_output: bool,
    ignore_items: set,
) -> None:
    """Generate the ``.pyi`` stub file for the public members of ``cls``.

    The import section of the existing stub (the lines found in
    ``source_path`` between the "generated by" header and the
    "end imports" marker) is copied verbatim to the new file.  After it,
    one stub is written for every name returned by ``dir(cls)`` that does
    not start with an underscore and is not listed in ``ignore_items``.
    The resulting file is then processed by ``zimports`` and ``black``.

    :param cls: class whose public members are stubbed.
    :param progname: name of the generating program, written to the header
     comment.  An absolute path is converted to a path relative to the
     current working directory.
    :param source_path: existing stub file the import section is read from.
    :param destination_path: file to write.  It may be the same file as
     ``source_path``, since the source is read completely before the
     destination is opened.
    :param ignore_output: passed to ``console_scripts()`` as
     ``ignore_output`` for both formatter runs.
    :param ignore_items: public names of ``cls`` to leave out of the stub.
    :raises RuntimeError: when run with a Python version older than 3.9.
    :raises ValueError: when ``progname`` is an absolute path that is not
     located below the current working directory.
    """
    if sys.version_info < (3, 9):
        raise RuntimeError("This script must be run with Python 3.9 or higher")

    # When using an absolute path on windows, this will generate the correct
    # relative path that shall be written to the top comment of the pyi file.
    if Path(progname).is_absolute():
        progname = Path(progname).relative_to(Path.cwd()).as_posix()

    # Collect the import section of the current stub.  The encoding is
    # explicit so that the result does not depend on the locale of the
    # machine running the script.
    imports = []
    read_imports = False
    with open(source_path, encoding="utf-8") as read_file:
        for line in read_file:
            if line.startswith("# ### this file stubs are generated by"):
                read_imports = True
            elif line.startswith("### end imports ###"):
                break
            elif read_imports:
                imports.append(line.rstrip())

    with open(destination_path, "w", encoding="utf-8") as buf:
        printer = PythonPrinter(buf)

        printer.writeline(
            f"# ### this file stubs are generated by {progname} "
            "- do not edit ###"
        )
        for line in imports:
            buf.write(line + "\n")
        printer.writeline("### end imports ###")
        buf.write("\n\n")

        # Namespace used to resolve the annotations of the members; later
        # entries take precedence, so the module defining ``cls`` wins.
        module = sys.modules[cls.__module__]
        env = {
            **sa.__dict__,
            **sa.types.__dict__,
            **ops.__dict__,
            **module.__dict__,
        }

        for name in dir(cls):
            if name.startswith("_") or name in ignore_items:
                continue
            meth = getattr(cls, name, None)
            if callable(meth):
                _generate_stub_for_meth(cls, name, printer, env)
            else:
                _generate_stub_for_attr(cls, name, printer, env)

        printer.close()

    console_scripts(
        str(destination_path),
        {"entrypoint": "zimports", "options": "-e"},
        ignore_output=ignore_output,
    )
    # note that we do not distribute pyproject.toml with the distribution
    # right now due to user complaints, so we can't refer to it here because
    # this all has to run as part of the test suite
    console_scripts(
        str(destination_path),
        {"entrypoint": "black", "options": "-l79"},
        ignore_output=ignore_output,
    )


def _generate_stub_for_attr(cls, name, printer, env):
    try:
        annotations = typing.get_type_hints(cls, env)
    except NameError as e:
        annotations = cls.__annotations__
    type_ = annotations.get(name, "Any")
    if isinstance(type_, str) and type_[0] in "'\"":
        type_ = type_[1:-1]
    printer.writeline(f"{name}: {type_}")


def _generate_stub_for_meth(cls, name, printer, env):
    """Write the stub of the method ``name`` of ``cls`` to ``printer``.

    The stub consists of the ``def`` line, built from the signature of the
    undecorated function with its leading ``self``/``cls`` argument
    removed, followed by the docstring of the function, or by ``...`` when
    the function has no docstring.

    :param cls: class that owns the method.
    :param name: attribute name of the method on ``cls``.
    :param printer: object whose ``write_indented_block()`` receives the
     generated text.
    :param env: namespace passed to ``typing.get_type_hints()`` to resolve
     the annotations of the function.
    """
    fn = getattr(cls, name)
    # Follow ``__wrapped__`` down to the innermost function, so that the
    # signature and docstring of that function are the ones used.
    while hasattr(fn, "__wrapped__"):
        fn = fn.__wrapped__

    spec = inspect_getfullargspec(fn)
    try:
        annotations = typing.get_type_hints(fn, env)
    except NameError:
        # An annotation refers to a name missing from ``env``: keep the
        # annotations found by inspect_getfullargspec() as they are.
        pass
    else:
        spec.annotations.update(annotations)

    name_args = spec[0]
    assert name_args[0:1] in (["self"], ["cls"]), (
        f"first argument of {name!r} is expected to be 'self' or 'cls'"
    )

    # Modified in place on purpose: ``spec`` is unpacked below and must
    # not contain the first argument anymore.
    name_args[0:1] = []

    def _formatannotation(annotation, base_module=None):
        """Return the text used in the stub for ``annotation``."""
        if getattr(annotation, "__module__", None) == "typing":
            retval = repr(annotation).replace("typing.", "")
        elif isinstance(annotation, type):
            if annotation.__module__ in ("builtins", base_module):
                retval = annotation.__qualname__
            else:
                retval = annotation.__module__ + "." + annotation.__qualname__
        elif isinstance(annotation, str):
            retval = annotation
        else:
            # Any other object is rendered with its repr, so that the
            # string operations below always receive a string.
            retval = repr(annotation)

        for trim in TRIM_MODULE:
            retval = retval.replace(trim, "")

        # ``ForwardRef('Name')`` is reduced to the quoted name alone.
        retval = re.sub(
            r'ForwardRef\(([\'"].+?[\'"])\)', lambda m: m.group(1), retval
        )
        retval = re.sub("NoneType", "None", retval)
        return retval

    argspec = inspect_formatargspec(
        *spec,
        formatannotation=_formatannotation,
        formatreturns=lambda val: f"-> {_formatannotation(val)}",
    )

    # A function without docstring gets an ellipsis body instead of the
    # docstring ``'''None'''``.
    if fn.__doc__ is None:
        body = "..."
    else:
        body = "'''%s'''" % fn.__doc__

    func_text = textwrap.dedent(
        """\
    def %(name)s%(argspec)s:
        %(body)s
    """
        % {
            "name": name,
            "argspec": argspec,
            "body": body,
        }
    )

    printer.write_indented_block(func_text)


def run_file(
    source_path: Path, cls_to_generate: type, stdout: bool, ignore_items: set
):
    """Regenerate one stub file, in place or to standard output.

    :param source_path: existing stub file; its import section is reused
     and, unless ``stdout`` is set, the file itself is overwritten.
    :param cls_to_generate: class whose public members are stubbed.
    :param stdout: when true the stub is generated in a temporary file,
     written to ``sys.stdout`` and the temporary file is removed, leaving
     ``source_path`` unchanged.
    :param ignore_items: public names of ``cls_to_generate`` to leave out.
    """
    progname = Path(sys.argv[0]).as_posix()
    if not stdout:
        generate_pyi_for_proxy(
            cls_to_generate,
            progname,
            source_path=source_path,
            destination_path=source_path,
            ignore_output=False,
            ignore_items=ignore_items,
        )
        return

    # Only the name of the temporary file is needed: it is created with
    # ``delete=False`` and closed right away, then reopened by path during
    # the generation.
    with NamedTemporaryFile(delete=False, suffix=".pyi") as f:
        pass
    f_path = Path(f.name)
    try:
        generate_pyi_for_proxy(
            cls_to_generate,
            progname,
            source_path=source_path,
            destination_path=f_path,
            ignore_output=True,
            ignore_items=ignore_items,
        )
        sys.stdout.write(f_path.read_text())
    finally:
        # Remove the temporary file also when the generation fails.
        f_path.unlink()


def main(args):
    """Regenerate the stub files selected by ``args.file``.

    ``args`` provides ``file`` ("op", "context" or "all") and ``stdout``,
    the latter being forwarded to run_file().  The "op" stub is handled
    before the "context" one.
    """
    package_dir = Path(__file__).parent.parent / "alembic"
    # Stub key -> proxied class; the key names both the ``.pyi`` file and
    # the entry of IGNORE_ITEMS to use.
    targets = (
        ("op", Operations),
        ("context", EnvironmentContext),
    )
    for key, proxied_cls in targets:
        if args.file not in {"all", key}:
            continue
        run_file(
            package_dir / f"{key}.pyi",
            proxied_cls,
            args.stdout,
            IGNORE_ITEMS[key],
        )


if __name__ == "__main__":
    # Command line entry point: parse the options and regenerate the
    # requested stub files.
    parser = ArgumentParser()
    parser.add_argument(
        "--file",
        # A tuple rather than a set, so that the choices are always listed
        # in the same order in the usage and error messages.
        choices=("op", "context", "all"),
        default="all",
        help="Which file to generate. Default is to regenerate all files",
    )
    parser.add_argument(
        "--stdout",
        action="store_true",
        help="Write to stdout instead of saving to file",
    )
    args = parser.parse_args()
    main(args)
