import os
import shutil
import subprocess
import tarfile
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from glob import glob
from operator import attrgetter
from pathlib import Path
from typing import Dict, Literal, Optional, Union, List, Iterable

import orjson
import toml
from graphlib import TopologicalSorter
from loguru import logger

from pbt.config import PBTConfig
from pbt.diff import diff_db, Diff
from pbt.poetry import Poetry
from pbt.version import parse_version


class PackageType(str, Enum):
    """Package-management tools whose projects pbt can drive.

    Being a ``str`` subclass, members compare equal to (and serialize as)
    their plain string value.
    """

    # Project configured through the `[tool.poetry]` section of pyproject.toml.
    Poetry = "poetry"


@dataclass
class Package:
    """A package of the repository, described by its ``pyproject.toml``.

    Besides the metadata read from the project file, each instance keeps the
    dependency edges to the other packages of the repository in both
    directions (``inter_dependencies`` / ``invert_inter_dependencies``).
    """

    # distribution name (`[tool.poetry] name`)
    name: str
    # tool managing the package; selects the `pkg_handler` implementation
    type: PackageType
    # directory that contains the package's pyproject.toml
    dir: Path
    # version string (`[tool.poetry] version`)
    version: str
    # a list of patterns to be included in the final package
    include: List[str]
    # a list of patterns to be excluded in the final package
    exclude: List[str]
    # declared dependencies: dependency name -> specification
    dependencies: Dict[str, str]
    # list of packages that this package uses
    inter_dependencies: List["Package"]
    # list of packages that use the current package
    invert_inter_dependencies: List["Package"]

    def build(self, cfg: PBTConfig) -> bool:
        """Build the package if needed.

        The build is skipped when a wheel already exists in ``dist/`` and the
        tracked content has not changed since the last recorded build.

        Returns:
            ``True`` if a build was run, ``False`` if it was skipped.
        """
        whl_file = self.get_wheel_file()
        with diff_db(self, cfg) as db:
            diff = Diff.from_local(db, self)
            if whl_file is not None and not diff.is_modified(db):
                logger.info(
                    "Skip package {} as the content does not change", self.name
                )
                return False

            self._build()
            # Record the content snapshot only once the build has succeeded;
            # saving it after a failure would make the next call skip a build
            # that never produced a fresh wheel.
            diff.save(db)
            return True

    def _build(self):
        """Run ``poetry build`` in the package directory after wiping ``dist/``."""
        logger.info("Build package {}", self.name)
        if (self.dir / "dist").exists():
            shutil.rmtree(str(self.dir / "dist"))
        subprocess.check_output(["poetry", "build"], cwd=str(self.dir))

    def publish(self):
        """Publish the package through its package-manager handler."""
        self.pkg_handler.publish()

    def compute_pip_hash(self, cfg: PBTConfig, no_build: bool = False) -> str:
        """Compute the sha256 digest that ``pip hash`` reports for the wheel.

        Args:
            cfg: configuration forwarded to :meth:`build`.
            no_build: skip the (conditional) build step and hash the wheel
                currently present in ``dist/``.

        Returns:
            The hex digest, without the ``sha256:`` prefix.
        """
        if not no_build:
            self.build(cfg)
        output = (
            subprocess.check_output(["pip", "hash", self.get_wheel_file()])
            .decode()
            .strip()
        )
        # `pip hash` prints "<file>:\n--hash=sha256:<digest>"
        output = output[output.find("--hash=") + len("--hash=") :]
        assert output.startswith("sha256:")
        return output[len("sha256:") :]

    def install_dep(
        self,
        package: "Package",
        cfg: PBTConfig,
        editable: bool = False,
        no_build: bool = False,
    ):
        """Install the `package` in the virtual environment of this package (`self`) (i.e., install dependency).

        Any installed copy is uninstalled first. In editable mode, ``setup.py``
        is extracted from `package`'s sdist into its directory and installed
        with ``pip install -e .``; otherwise the built wheel is installed.
        """
        if not no_build:
            package.build(cfg)

        logger.info(
            "Current package {}: install dependency {}", self.name, package.name
        )
        pipfile = self.pkg_handler.pip_path
        subprocess.check_output([pipfile, "uninstall", "-y", package.name])

        if editable:
            with tarfile.open(package.get_tar_file(), "r") as g:
                member = g.getmember(f"{package.name}-{package.version}/setup.py")
                with open(package.dir / "setup.py", "wb") as f:
                    f.write(g.extractfile(member).read())
            subprocess.check_call([pipfile, "install", "-e", "."], cwd=package.dir)
        else:
            subprocess.check_output([pipfile, "install", package.get_wheel_file()])

    def reload(self):
        """Re-read this package's metadata from its pyproject.toml in place.

        Dependency edges are kept; they must still be declared dependencies.
        """
        pkg = load_package(self.dir)
        self.name = pkg.name
        self.version = pkg.version
        self.dependencies = pkg.dependencies
        self.include = pkg.include
        self.exclude = pkg.exclude
        assert all(dep.name in self.dependencies for dep in self.inter_dependencies)

    def is_package_compatible(self, package: "Package") -> bool:
        """Whether `package`'s current version satisfies the constraint this
        package declares for it."""
        return self.pkg_handler.is_version_compatible(
            package.version, self.dependencies[package.name]
        )

    def all_inter_dependencies(self) -> Dict[str, "Package"]:
        """Get all inter dependencies of a package. It won't warn if there is any cycle."""
        stack = [self]
        deps = {}
        while len(stack) > 0:
            ptr = stack.pop()
            for dep in ptr.inter_dependencies:
                if dep.name not in deps:
                    stack.append(dep)
                    deps[dep.name] = dep
        return deps

    def next_version(self, rule: Literal["major", "minor", "patch"]):
        """Bump the version by `rule` and write it back via the handler."""
        self.version = str(parse_version(self.version).next_version(rule))
        self.pkg_handler.replace_version()

    def update_package_version(self, package: "Package"):
        """Update the version of another package in this package"""
        assert package.name in self.dependencies
        self.pkg_handler.update_inter_dependency(package.name, package.version)

    @cached_property
    def pkg_handler(self):
        """Tool-specific handler for this package (created once, then cached)."""
        if self.type == PackageType.Poetry:
            return Poetry(self)
        raise NotImplementedError(self.type)

    def get_tar_file(self) -> Optional[str]:
        """Path of ``dist/<name>-<version>.tar.gz`` if it exists, else None."""
        tar_file = self.dir / "dist" / f"{self.name}-{self.version}.tar.gz"
        if tar_file.exists():
            return str(tar_file)
        return None

    def get_wheel_file(self) -> Optional[str]:
        """Path of a wheel in ``dist/`` named after the package, else None.

        The name's dashes are replaced with underscores for the lookup; if
        several wheels match, the first one returned by ``glob`` is used.
        """
        whl_files = glob(str(self.dir / f"dist/{self.name.replace('-', '_')}*.whl"))
        if len(whl_files) == 0:
            return None
        return whl_files[0]

    def filter_included_files(self, files: List[str]) -> List[str]:
        """Filter files that will be included in the final distributed packages.

        Only literal (non-glob) ``include`` entries are supported. A file is
        kept when it equals an include path, lies inside it as a directory, or
        is the ``<include>.py`` module.

        Raises:
            NotImplementedError: if any include entry looks like a glob.
        """
        dir_paths = []
        patterns = []

        for pattern in self.include:
            pattern = os.path.join(self.dir, pattern)
            # TODO: detect the pattern better
            if "*" in pattern or "?" in pattern or ("[" in pattern and "]" in pattern):
                patterns.append(pattern)
            else:
                dir_paths.append(pattern.rstrip(os.sep))

        if len(patterns) > 0:
            raise NotImplementedError()

        def is_included(file: str) -> bool:
            # Match on path-component boundaries so that including `pkg` does
            # not also pick up unrelated siblings such as `pkg_tests/...`.
            return any(
                file == dpath
                or file.startswith(dpath + os.sep)
                or file == dpath + ".py"
                for dpath in dir_paths
            )

        return [file for file in files if is_included(file)]

    def to_dict(self):
        """Plain-dict view of the package; related packages are referenced by
        name so the result contains no object cycles."""
        return {
            "name": self.name,
            "type": self.type,
            "dir": self.dir,
            "version": self.version,
            "include": self.include,
            "exclude": self.exclude,
            "dependencies": self.dependencies,
            "inter_dependencies": [p.name for p in self.inter_dependencies],
            "invert_inter_dependencies": [
                p.name for p in self.invert_inter_dependencies
            ],
        }

    @staticmethod
    def save(packages: Dict[str, "Package"], outfile: Union[str, Path]):
        """Serialize `packages` to `outfile` as JSON (readable by :meth:`load`)."""

        def encode_extra(obj):
            # orjson has no native support for pathlib.Path (used by `dir`).
            if isinstance(obj, Path):
                return str(obj)
            raise TypeError

        with open(str(outfile), "wb") as f:
            f.write(
                orjson.dumps(
                    {package.name: package.to_dict() for package in packages.values()},
                    default=encode_extra,
                )
            )

    @staticmethod
    def load(infile: Union[str, Path]) -> Dict[str, "Package"]:
        """Load packages written by :meth:`save`, restoring dependency edges.

        Files written before ``include``/``exclude`` were stored load with
        empty lists for those fields.
        """
        with open(str(infile), "rb") as f:
            raw_packages = orjson.loads(f.read())

        # First create every package, then resolve the name references, since
        # edges may point to packages that appear later in the file.
        packages = {
            name: Package(
                name=o["name"],
                type=PackageType(o["type"]),
                dir=Path(o["dir"]),
                version=o["version"],
                include=o.get("include", []),
                exclude=o.get("exclude", []),
                dependencies=o["dependencies"],
                inter_dependencies=[],
                invert_inter_dependencies=[],
            )
            for name, o in raw_packages.items()
        }
        for name, o in raw_packages.items():
            package = packages[name]
            package.inter_dependencies = [
                packages[dep_name] for dep_name in o["inter_dependencies"]
            ]
            package.invert_inter_dependencies = [
                packages[dep_name] for dep_name in o["invert_inter_dependencies"]
            ]
        return packages

    def __repr__(self) -> str:
        return f"{self.name}={self.version}"


def search_packages(pbt_cfg: PBTConfig) -> Dict[str, Package]:
    """Discover the packages that sit one directory below ``pbt_cfg.cwd``.

    Every ``*/pyproject.toml`` is loaded with ``load_package``. Afterwards the
    dependency edges between the discovered packages are linked in both
    directions, and each edge list is sorted by package name.

    Returns:
        Mapping from package name to package.
    """
    logger.info("Search packages...")
    found = {}

    for pyproject in glob(str(pbt_cfg.cwd / "*/pyproject.toml")):
        package = load_package(Path(pyproject).parent)
        logger.info("Found package {}", package.name)
        found[package.name] = package

    # Link a consumer to each declared dependency that is also a local package.
    local_names = set(found)
    for consumer in found.values():
        for dep_name in local_names.intersection(consumer.dependencies):
            provider = found[dep_name]
            consumer.inter_dependencies.append(provider)
            provider.invert_inter_dependencies.append(consumer)

    # Sort by name so the edge order does not depend on discovery order.
    by_name = attrgetter("name")
    for package in found.values():
        package.inter_dependencies.sort(key=by_name)
        package.invert_inter_dependencies.sort(key=by_name)
    return found


def load_package(pkg_dir: Path) -> Package:
    """Build a :class:`Package` from the ``[tool.poetry]`` section of
    ``pkg_dir/pyproject.toml``.

    ``include`` gathers the paths of the ``include`` entries, the package name
    itself, and every ``packages`` entry joined with its ``from`` directory.
    The result is de-duplicated and sorted. The returned package has no
    inter-dependency edges yet.

    Raises:
        Any error raised while reading or parsing the file (missing file,
        invalid TOML, missing ``name``/``version``/``dependencies`` keys, ...).
        It is logged together with ``pkg_dir`` and re-raised unchanged.
    """
    poetry_file = pkg_dir / "pyproject.toml"
    try:
        with open(poetry_file, "r") as f:
            project_cfg = toml.loads(f.read())
        poetry_cfg = project_cfg["tool"]["poetry"]
        pkg_name = poetry_cfg["name"]
        pkg_version = poetry_cfg["version"]
        pkg_dependencies = poetry_cfg["dependencies"]

        # see https://python-poetry.org/docs/pyproject/#include-and-exclude
        # and https://python-poetry.org/docs/pyproject/#packages
        # `include` entries are plain strings or tables such as
        # `{ path = "tests", format = "sdist" }`; keep only the path so the
        # entries stay hashable for the de-duplication below. Building new
        # lists also avoids mutating the parsed config.
        pkg_include = [
            entry["path"] if isinstance(entry, dict) else entry
            for entry in poetry_cfg.get("include", [])
        ]
        pkg_exclude = list(poetry_cfg.get("exclude", []))
        pkg_include.append(pkg_name)
        for pkg_cfg in poetry_cfg.get("packages", []):
            pkg_include.append(
                os.path.join(pkg_cfg.get("from", ""), pkg_cfg["include"])
            )
        pkg_include = sorted(set(pkg_include))
    except Exception:
        logger.error("Error while parsing configuration in {}", pkg_dir)
        raise

    return Package(
        name=pkg_name,
        type=PackageType.Poetry,
        version=pkg_version,
        dir=pkg_dir,
        include=pkg_include,
        exclude=pkg_exclude,
        dependencies=pkg_dependencies,
        inter_dependencies=[],
        invert_inter_dependencies=[],
    )


def topological_sort(packages: Dict[str, Package]) -> List[str]:
    """Order package names so that every package comes after the packages it uses.

    The first item is therefore always a leaf of the dependency graph, i.e. a
    package that doesn't use any other package in the repository. A dependency
    cycle makes ``graphlib`` raise ``CycleError``.
    """
    # TopologicalSorter expects `node -> predecessors`; a package's
    # predecessors are the repository packages it depends on.
    dependency_graph = {
        pkg.name: {dep.name for dep in pkg.inter_dependencies}
        for pkg in packages.values()
    }
    sorter = TopologicalSorter(dependency_graph)
    return list(sorter.static_order())


def update_versions(
    updated_pkg_names: Iterable[str], packages: Dict[str, Package], force: bool = False
):
    """Propagate the new versions of `updated_pkg_names` to their direct users.

    For each updated package, every package listed in its
    ``invert_inter_dependencies`` gets its declared dependency on it rewritten.
    This happens when `force` is set, or when the declared constraint no longer
    accepts the new version. Only direct users are visited; the function does
    not recurse further.
    """
    for updated in (packages[name] for name in updated_pkg_names):
        for user in updated.invert_inter_dependencies:
            if not force and user.is_package_compatible(updated):
                # constraint still accepts the new version; leave it alone
                continue
            user.update_package_version(updated)