from __future__ import annotations

import builtins
import os
import re
import shutil
from textwrap import dedent, indent
from pathlib import Path
from typing import (
    Generator,
    Iterable,
    List,
    Dict,
    Set,
    Union,
    Iterator,
    Tuple,
    Optional,
)
from dataclasses import dataclass

import black
import isort
from xmlschema import XMLSchema, qnames
from xmlschema.validators import (
    XsdAnyAttribute,
    XsdAnyElement,
    XsdAttribute,
    XsdElement,
    XsdType,
    XsdComponent,
)


# FIXME: Work out a better way to implement these override hacks.


@dataclass
class Override:
    """Hand-written replacement for the generated output of one schema member.

    Attributes (all code snippets are Python source text):
        type_: annotation to emit for the field.
        default: default-value expression, or None for no default.
        imports: module-level code placed before the class (imports and any
            helper definitions). Normalised with ``dedent``.
        body: extra class-body code (e.g. validators). Normalised with
            ``dedent`` and then indented one level so it can be appended
            directly inside a class.
    """

    type_: str
    default: Optional[str] = None
    imports: Optional[str] = None
    body: Optional[str] = None

    def __post_init__(self) -> None:
        # Snippets are authored as indented triple-quoted strings in
        # OVERRIDES; strip that authoring indentation here.
        if self.imports:
            self.imports = dedent(self.imports)
        if not self.body:
            return
        self.body = indent(dedent(self.body), prefix="    ")


# Maps XSD TypeName to Override configuration, used to control output for that type.
# Keys are either "<Parent>/<Member>" (matched against Member.key for a member
# named <Member> whose first named ancestor is <Parent>) or a bare local name,
# which Member.key falls back to when no "<Parent>/<Member>" entry exists.
# convert_schema also skips generating a module for any global type/element
# whose local name is a key here.
OVERRIDES = {
    # Emitted as a plain bool field defaulting to False.
    "MetadataOnly": Override(type_="bool", default="False"),
    # FIXME: Type should be xml.etree.ElementTree.Element but isinstance checks
    # with that class often mysteriously fail so the validator fails.
    "XMLAnnotation/Value": Override(type_="Any", imports="from typing import Any"),
    "BinData/Length": Override(type_="int"),
    # FIXME: hard-coded subclass lists
    # Polymorphic list fields below: the emitted validator accepts either an
    # instance of the base class or a dict whose "_type" key selects the
    # concrete subclass from the lookup table defined in ``imports``.
    "Instrument/LightSourceGroup": Override(
        type_="List[LightSource]",
        default="field(default_factory=list)",
        imports="""
            from typing import Dict, Union, Any
            from pydantic import validator
            from .light_source import LightSource
            from .laser import Laser
            from .arc import Arc
            from .filament import Filament
            from .light_emitting_diode import LightEmittingDiode
            from .generic_excitation_source import GenericExcitationSource

            _light_source_types: Dict[str, type] = {
                "laser": Laser,
                "arc": Arc,
                "filament": Filament,
                "light_emitting_diode": LightEmittingDiode,
                "generic_excitation_source": GenericExcitationSource,
            }
        """,
        body="""
            @validator("light_source_group", pre=True, each_item=True)
            def validate_light_source_group(
                cls, value: Union[LightSource, Dict[Any, Any]]
            ) -> LightSource:
                if isinstance(value, LightSource):
                    return value
                elif isinstance(value, dict):
                    try:
                        _type = value.pop("_type")
                    except KeyError:
                        raise ValueError(
                            "dict initialization requires _type"
                        ) from None
                    try:
                        light_source_cls = _light_source_types[_type]
                    except KeyError:
                        raise ValueError(
                            f"unknown LightSource type '{_type}'"
                        ) from None
                    return light_source_cls(**value)
                else:
                    raise ValueError("invalid type for light_source_group values")
        """,
    ),
    "ROI/Union": Override(
        type_="List[Shape]",
        default="field(default_factory=list)",
        imports="""
            from typing import Dict, Union, Any
            from pydantic import validator
            from .shape import Shape
            from .point import Point
            from .line import Line
            from .rectangle import Rectangle
            from .ellipse import Ellipse
            from .polyline import Polyline
            from .polygon import Polygon
            from .mask import Mask
            from .label import Label

            _shape_types: Dict[str, type] = {
                "point": Point,
                "line": Line,
                "rectangle": Rectangle,
                "ellipse": Ellipse,
                "polyline": Polyline,
                "polygon": Polygon,
                "mask": Mask,
                "label": Label,
            }
        """,
        body="""
            @validator("union", pre=True, each_item=True)
            def validate_union(
                cls, value: Union[Shape, Dict[Any, Any]]
            ) -> Shape:
                if isinstance(value, Shape):
                    return value
                elif isinstance(value, dict):
                    try:
                        _type = value.pop("_type")
                    except KeyError:
                        raise ValueError(
                            "dict initialization requires _type"
                        ) from None
                    try:
                        shape_cls = _shape_types[_type]
                    except KeyError:
                        raise ValueError(f"unknown Shape type '{_type}'") from None
                    return shape_cls(**value)
                else:
                    raise ValueError("invalid type for union values")
        """,
    ),
    "OME/StructuredAnnotations": Override(
        type_="List[Annotation]",
        default="field(default_factory=list)",
        imports="""
            from typing import Dict, Union, Any
            from pydantic import validator
            from .annotation import Annotation
            from .boolean_annotation import BooleanAnnotation
            from .comment_annotation import CommentAnnotation
            from .double_annotation import DoubleAnnotation
            from .file_annotation import FileAnnotation
            from .list_annotation import ListAnnotation
            from .long_annotation import LongAnnotation
            from .tag_annotation import TagAnnotation
            from .term_annotation import TermAnnotation
            from .timestamp_annotation import TimestampAnnotation
            from .xml_annotation import XMLAnnotation

            _annotation_types: Dict[str, type] = {
                "boolean_annotation": BooleanAnnotation,
                "comment_annotation": CommentAnnotation,
                "double_annotation": DoubleAnnotation,
                "file_annotation": FileAnnotation,
                "list_annotation": ListAnnotation,
                "long_annotation": LongAnnotation,
                "tag_annotation": TagAnnotation,
                "term_annotation": TermAnnotation,
                "timestamp_annotation": TimestampAnnotation,
                "xml_annotation": XMLAnnotation,
            }
        """,
        body="""
            @validator("structured_annotations", pre=True, each_item=True)
            def validate_structured_annotations(
                cls, value: Union[Annotation, Dict[Any, Any]]
            ) -> Annotation:
                if isinstance(value, Annotation):
                    return value
                elif isinstance(value, dict):
                    try:
                        _type = value.pop("_type")
                    except KeyError:
                        raise ValueError(
                            "dict initialization requires _type"
                        ) from None
                    try:
                        annotation_cls = _annotation_types[_type]
                    except KeyError:
                        raise ValueError(f"unknown Annotation type '{_type}'") from None
                    return annotation_cls(**value)
                else:
                    raise ValueError("invalid type for annotation values")
        """,
    ),
    # The ``imports`` snippet here also defines the small UUID dataclass used
    # as the field type, instead of importing a generated module.
    "TiffData/UUID": Override(
        type_="Optional[UUID]",
        default="None",
        imports="""
            from typing import Optional
            from .simple_types import UniversallyUniqueIdentifier

            @ome_dataclass
            class UUID:
                file_name: str
                value: UniversallyUniqueIdentifier
        """,
    ),
}


def black_format(text: str, line_length: int = 79) -> str:
    """Return *text* reformatted by black with the given maximum line length."""
    mode = black.FileMode(line_length=line_length)
    formatted = black.format_str(text, mode=mode)
    return formatted


def sort_imports(text: str) -> str:
    """Return *text* with its import block sorted by isort.

    isort 5 removed the ``SortImports`` class in favour of the module-level
    ``isort.code`` function. Prefer the modern API when it exists and fall
    back to the legacy class so that isort 4 installations keep working.
    """
    if hasattr(isort, "code"):
        return isort.code(text)
    return isort.SortImports(file_contents=text).output


def sort_types(el: XsdType) -> str:
    """Sort key for global schema types.

    Simple types whose base type is not itself a restriction get a
    four-space prefix, so they sort ahead of everything else. Presumably
    this is so that simple types derived from other restrictions are emitted
    after their bases in simple_types.py.
    """
    key = el.local_name.lower()
    if el.is_complex() or el.base_type.is_restriction():
        return key
    return "    " + key


def sort_prop(prop: Member) -> str:
    """Sort key for member field lines.

    Members without a default get a three-space prefix. A space sorts before
    any identifier character, so required fields are listed before defaulted
    ones. Within each group, fields sort by their lower-cased formatted line.
    """
    prefix = "   " if not prop.default_val_str else ""
    return prefix + prop.format().lower()


def as_identifier(s: str) -> str:
    """Reduce *s* to an ASCII identifier.

    Every character outside ``[0-9a-zA-Z_]`` is dropped. Leading digits are
    then stripped, so the result starts with a letter or underscore.

    Raises:
        ValueError: if nothing usable remains.
    """
    # After the first pass only ASCII letters, digits and "_" remain, so
    # "leading non letter/underscore" is the same as "leading digits".
    cleaned = re.sub(r"[^0-9a-zA-Z_]", "", s).lstrip("0123456789")
    if cleaned:
        return cleaned
    raise ValueError(f"Could not clean {s}: nothing left")


# Names whose snake_case form cannot be derived by the regex conversion below.
_CAMEL_SNAKE_OVERRIDES = {"ROIs": "rois"}


def camel_to_snake(name: str) -> str:
    """Convert a CamelCase schema name to snake_case.

    Entries in ``_CAMEL_SNAKE_OVERRIDES`` win. Otherwise word boundaries are
    split with underscores, the result is lower-cased, and spaces become
    underscores (e.g. "XMLAnnotation" -> "xml_annotation").
    """
    special = _CAMEL_SNAKE_OVERRIDES.get(name)
    if special:
        return special
    # FIXME This part must be kept identical to the copy of this function in
    # the schema module. Ideally we would have one shared implementation but
    # currently there is a problem importing anything from ome_types if the
    # model code hasn't been generated yet. It should be fixable with a
    # little reorganization.
    # https://stackoverflow.com/a/1176023
    snake = re.sub("([A-Z]+)([A-Z][a-z]+)", r"\1_\2", name)
    snake = re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake)
    return snake.lower().replace(" ", "_")


def local_import(item_type: str) -> str:
    """Return the relative import line for a generated class *item_type*.

    Each generated class lives in a sibling module named after the
    snake_case form of the class name.
    """
    module = camel_to_snake(item_type)
    return "from .{} import {}".format(module, item_type)


def make_dataclass(component: Union[XsdComponent, XsdType]) -> List[str]:
    """Generate the source lines of an ``@ome_dataclass`` for *component*.

    The lines appear in this order:
      * the ``ome_dataclass`` import;
      * an import of the base class, if any;
      * the members' imports and locally defined helper classes;
      * the class header, field lines and any override bodies.

    Members already declared on any ancestor type are excluded, so inherited
    fields are not re-declared. If the class has a base and at least one
    member lacks a default, every field is given ``= EMPTY``. This is
    presumably because dataclass subclasses cannot declare required fields
    after inherited defaulted ones.
    """
    lines = ["from ome_types.dataclasses import ome_dataclass", ""]
    # FIXME: Refactor to remove BinData special-case.
    if component.local_name == "BinData":
        base_type = None
    elif isinstance(component, XsdType):
        base_type = component.base_type
    else:
        base_type = component.type.base_type

    # Builtin XSD types expose ``python_type``; only schema-defined bases
    # become a Python base class.
    if base_type and not hasattr(base_type, "python_type"):
        base_name = f"({base_type.local_name})"
        if base_type.is_complex():
            lines += [local_import(base_type.local_name)]
        else:
            lines += [f"from .simple_types import {base_type.local_name}"]
    else:
        base_name = ""

    # Walk the whole inheritance chain, collecting members from *each*
    # ancestor (previously only the immediate base was inspected on every
    # iteration, so grand-ancestors were never considered).
    base_members = set()
    ancestor = base_type
    while ancestor:
        base_members.update(iter_members(ancestor))
        ancestor = ancestor.base_type

    members = MemberSet(m for m in iter_members(component) if m not in base_members)
    lines += members.imports()
    lines += members.locals()

    cannot_have_required_args = base_type and members.has_non_default_args()
    if cannot_have_required_args:
        lines[0] += ", EMPTY"

    lines += ["@ome_dataclass", f"class {component.local_name}{base_name}:"]
    # FIXME: Refactor to remove BinData special-case.
    if component.local_name == "BinData":
        lines.append("    value: str")
    lines += members.lines(
        indent=1,
        force_defaults=" = EMPTY  # type: ignore"
        if cannot_have_required_args
        else None,
    )

    lines += members.body()

    return lines


def make_enum(component: XsdComponent, name: Optional[str] = None) -> List[str]:
    """Generate source lines for an ``Enum`` built from an XSD enumeration.

    Args:
        component: simple type carrying an enumeration facet.
        name: class name to emit; defaults to ``component.local_name``.

    Member names come from the ``enum`` attribute of ``enum`` elements found
    under the component's XML element, when any exist. These are paired
    positionally with the enumeration values; ``zip`` truncates to the
    shorter of the two. Otherwise each value's snake_case form is used as the
    name. Members are emitted sorted by (name, value).
    """
    name = name or component.local_name
    lines = ["from enum import Enum", ""]
    lines += [f"class {name}(Enum):"]
    enum_elems = list(component.elem.iter("enum"))
    facets = component.get_facet(qnames.XSD_ENUMERATION)
    members: List[Tuple[str, str]] = []
    if enum_elems:
        for el, value in zip(enum_elems, facets.enumeration):
            _name = el.attrib["enum"]
            if component.base_type.python_type.__name__ == "str":
                # repr() produces a correctly escaped literal even when the
                # value contains quotes or backslashes (black normalises the
                # quote style afterwards).
                value = repr(value)
            members.append((_name, value))
    else:
        for e in facets.enumeration:
            members.append((camel_to_snake(e), repr(e)))

    for n, v in sorted(members):
        lines.append(f"    {as_identifier(n).upper()} = {v}")
    return lines


# Maps XSD facet names to functions that render a facet as class-attribute
# lines for a pydantic ``Constrained*`` subclass (see
# GlobalElem._simple_class). Only the first regex of a pattern facet is used.
facet_converters = {
    # repr() always yields a valid, correctly escaped literal; a hand-built
    # r'...' literal breaks on patterns containing "'" or ending in "\".
    qnames.XSD_PATTERN: lambda f: [f"regex = re.compile({f.regexps[0]!r})"],
    qnames.XSD_MIN_INCLUSIVE: lambda f: [f"ge = {f.value}"],
    qnames.XSD_MIN_EXCLUSIVE: lambda f: [f"gt = {f.value}"],
    qnames.XSD_MAX_INCLUSIVE: lambda f: [f"le = {f.value}"],
    qnames.XSD_MAX_EXCLUSIVE: lambda f: [f"lt = {f.value}"],
    # An exact length is expressed as equal min and max lengths.
    qnames.XSD_LENGTH: lambda f: [f"min_length = {f.value}", f"max_length = {f.value}"],
    qnames.XSD_MIN_LENGTH: lambda f: [f"min_length = {f.value}"],
    qnames.XSD_MAX_LENGTH: lambda f: [f"max_length = {f.value}"],
}


def iter_all_members(
    component: XsdComponent,
) -> Generator[Union[XsdElement, XsdAttribute], None, None]:
    """Yield the elements and attributes reported by ``iter_components``.

    ``iter_components`` may include *component* itself when it matches the
    requested classes; that self-match is skipped.
    """
    wanted = (XsdElement, XsdAttribute)
    yield from (
        sub for sub in component.iter_components(wanted) if sub is not component
    )


def iter_members(
    component: Union[XsdElement, XsdType]
) -> Generator[Union[XsdElement, XsdAttribute], None, None]:
    """Yield the members that become fields of *component*.

    For an element, its attributes come first (entries that are not
    ``XsdAttribute`` are skipped), followed by its child elements. For a
    type, the members come from :func:`iter_all_members`.
    """
    if not isinstance(component, XsdElement):
        yield from iter_all_members(component)
        return
    attributes = component.attributes.values()
    yield from (a for a in attributes if isinstance(a, XsdAttribute))
    yield from component.iterchildren()


class Member:
    """Code-generation view of one non-global XSD element or attribute.

    Derives what is needed to emit a dataclass field for the component:
      * the Python identifier;
      * the type annotation and default value;
      * the required imports and any locally defined classes.

    Entries in ``OVERRIDES`` (looked up via :attr:`key`) take precedence over
    the schema-derived values.
    """

    # Stores plurals from all Members for later access.
    # Maps (parent name, snake_case name) -> snake_case plural. It is filled
    # as a side effect of reading ``identifier`` and is exported by
    # convert_schema as ``_field_plurals``.
    plurals_registry: Dict[Tuple[str, str], str] = {}

    def __init__(self, component: Union[XsdElement, XsdAttribute]):
        self.component = component
        assert not component.is_global()

    @property
    def identifier(self) -> str:
        """snake_case field name, using the schema plural when one exists."""
        if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)):
            return self.component.local_name
        name = camel_to_snake(self.component.local_name)
        if self.plural:
            plural = camel_to_snake(self.plural)
            Member.plurals_registry[(self.parent_name, name)] = plural
            name = plural
        ident = camel_to_snake(name)
        if not ident.isidentifier():
            raise ValueError(f"failed to make identifier of {self!r}")
        return ident

    @property
    def plural(self) -> Optional[str]:
        """Plural form of component name, if available."""
        if (
            isinstance(self.component, XsdElement)
            and self.component.is_multiple()
            and self.component.ref
            and self.component.ref.annotation
        ):
            appinfo = self.component.ref.annotation.appinfo
            assert len(appinfo) == 1, "unexpected multiple appinfo elements"
            plural = appinfo[0].find("xsdfu/plural")
            if plural is not None:
                return plural.text
        return None

    @property
    def type(self) -> XsdType:
        return self.component.type

    @property
    def is_enum_type(self) -> bool:
        return self.type.get_facet(qnames.XSD_ENUMERATION) is not None

    @property
    def is_builtin_type(self) -> bool:
        # Builtin XSD types carry a ``python_type`` attribute.
        return hasattr(self.type, "python_type")

    @property
    def is_decimal(self) -> bool:
        return self.component.type.is_derived(
            self.component.schema.builtin_types()["decimal"]
        )

    @property
    def parent_name(self) -> str:
        """Local name of component's first named ancestor."""
        p = self.component.parent
        while not p.local_name and p.parent is not None:
            p = p.parent
        return p.local_name

    @property
    def key(self) -> str:
        """OVERRIDES lookup key: "Parent/Name", else bare "Name" if present."""
        name = f"{self.parent_name}/{self.component.local_name}"
        if name not in OVERRIDES and self.component.local_name in OVERRIDES:
            return self.component.local_name
        return name

    def locals(self) -> Set[str]:
        """Source of classes that must be defined in the parent's module."""
        if self.key in OVERRIDES:
            return set()
        if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)):
            return set()
        if not self.type or self.type.is_global():
            return set()
        locals_: Set[str] = set()
        # FIXME: this bit is mostly hacks
        if self.type.is_complex() and self.component.ref is None:
            locals_.add("\n".join(make_dataclass(self.component)) + "\n")
        if self.type.is_restriction() and self.is_enum_type:
            locals_.add(
                "\n".join(make_enum(self.type, name=self.component.local_name)) + "\n"
            )
        return locals_

    def imports(self) -> Set[str]:
        """Import statements needed by this member's field declaration."""
        if self.key in OVERRIDES:
            _imp = OVERRIDES[self.key].imports
            return set([_imp]) if _imp else set()
        if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)):
            return set(["from typing import Any"])
        imports = set()
        if not self.max_occurs:
            imports.add("from typing import List")
            if self.is_optional:
                imports.add("from dataclasses import field")
        elif self.is_optional:
            imports.add("from typing import Optional")
        if self.is_decimal:
            imports.add("from typing import cast")
        if self.type.is_datetime():
            imports.add("from datetime import datetime")
        if not self.is_builtin_type and self.type.is_global():
            # FIXME: hack
            if not self.type.local_name == "anyType":
                if self.type.is_complex():
                    imports.add(local_import(self.type.local_name))
                else:
                    imports.add(f"from .simple_types import {self.type.local_name}")

        if self.component.ref is not None:
            if self.component.ref.local_name not in OVERRIDES:
                imports.add(local_import(self.component.ref.local_name))

        return imports

    def body(self) -> str:
        """Extra class-body code from OVERRIDES, or an empty string."""
        if self.key in OVERRIDES:
            return OVERRIDES[self.key].body or ""
        return ""

    @property
    def type_string(self) -> str:
        """single type, without Optional, etc..."""
        if self.key in OVERRIDES:
            return OVERRIDES[self.key].type_
        if isinstance(self.component, (XsdAnyElement, XsdAnyAttribute)):
            return "Any"
        if self.component.ref is not None:
            assert self.component.ref.is_global()
            return self.component.ref.local_name

        if self.type.is_datetime():
            return "datetime"
        if self.is_builtin_type:
            return self.type.python_type.__name__

        if self.type.is_global():
            return self.type.local_name
        elif self.type.is_complex():
            return self.component.local_name

        if self.type.is_restriction():
            # enumeration
            enum = self.type.get_facet(qnames.XSD_ENUMERATION)
            if enum:
                return self.component.local_name
            if self.type.base_type.local_name == "string":
                return "str"
        return ""

    @property
    def full_type_string(self) -> str:
        """full type, like Optional[List[str]]"""
        if self.key in OVERRIDES and self.type_string:
            return f": {self.type_string}"
        type_string = self.type_string
        if not type_string:
            return ""
        if not self.max_occurs:
            type_string = f"List[{type_string}]"
        elif self.is_optional:
            type_string = f"Optional[{type_string}]"
        return f": {type_string}" if type_string else ""

    @staticmethod
    def _coerce_default(type_name: str, value: object) -> object:
        """Convert an XSD default literal to an instance of builtin *type_name*.

        ``bool("false")`` is ``True``, so xsd:boolean literals ("true",
        "false", "1", "0") are parsed explicitly. Every other builtin type is
        constructed directly from the literal.

        Raises:
            ValueError: for an invalid xsd:boolean literal, or if the builtin
                constructor rejects *value*.
        """
        if type_name == "bool" and isinstance(value, str):
            lexical = value.strip()
            if lexical in ("true", "1"):
                return True
            if lexical in ("false", "0"):
                return False
            raise ValueError(f"invalid xsd:boolean default {value!r}")
        return getattr(builtins, type_name)(value)

    @property
    def default_val_str(self) -> str:
        """Field default as source text (e.g. `` = None``); "" if required."""
        if self.key in OVERRIDES:
            default = OVERRIDES[self.key].default
            return f" = {default}" if default else ""
        if not self.is_optional:
            return ""

        if not self.max_occurs:
            default_val = "field(default_factory=list)"
        else:
            default_val = self.component.default
            if default_val is not None:
                if self.is_enum_type:
                    default_val = f"{self.type_string}('{default_val}')"
                elif hasattr(builtins, self.type_string):
                    default_val = repr(
                        self._coerce_default(self.type_string, default_val)
                    )
                if self.is_decimal:
                    default_val = f"cast({self.type_string}, {default_val})"
            else:
                default_val = "None"
        return f" = {default_val}"

    @property
    def max_occurs(self) -> Optional[int]:
        # None (treated as falsy elsewhere) means unbounded; attributes have
        # no max_occurs and count as single-valued.
        return getattr(self.component, "max_occurs", 1)

    @property
    def is_optional(self) -> bool:
        # FIXME: hack.  doesn't fully capture the restriction
        if self.identifier == "id":
            return True
        if getattr(self.component.parent, "model", "") == "choice":
            return True
        if hasattr(self.component, "min_occurs"):
            return self.component.min_occurs == 0
        return self.component.is_optional()

    def __repr__(self) -> str:
        type_ = "element" if isinstance(self.component, XsdElement) else "attribute"
        return f"<Member {type_} {self.component.local_name}>"

    def format(self, force_default: Optional[str] = None) -> str:
        """Render the field line; *force_default* applies when none exists."""
        default = self.default_val_str
        if force_default:
            default = default or force_default
        return f"{self.identifier}{self.full_type_string}{default}"


class MemberSet:
    """Collection of :class:`Member` objects for one generated dataclass.

    Aggregates the members' field lines, imports, local helper definitions
    and override bodies. Members are kept in a set and hash by identity
    (Member defines no ``__hash__``), so raw iteration order varies between
    runs. Every aggregate is therefore sorted to keep the generated code
    stable.
    """

    def __init__(self, initial: Iterable[Member] = ()):
        self._members: Set[Member] = set()
        self.update(initial)

    def add(self, member: Member) -> None:
        """Add *member*, wrapping a raw XSD component in :class:`Member`."""
        if not isinstance(member, Member):
            member = Member(member)
        self._members.add(member)

    def update(self, members: Iterable[Member]) -> None:
        for member in members:
            self.add(member)

    def lines(self, indent: int = 1, force_defaults: Optional[str] = None) -> List[str]:
        """Indented field lines (required first), or ``pass`` when empty."""
        if not self._members:
            lines = ["    " * indent + "pass"]
        else:
            lines = [
                "    " * indent + m.format(force_defaults)
                for m in sorted(self._members, key=sort_prop)
            ]
        return lines

    def imports(self) -> List[str]:
        """Sorted, de-duplicated union of all members' import snippets."""
        return sorted(set().union(*(m.imports() for m in self._members)))

    def locals(self) -> List[str]:
        """Sorted, de-duplicated union of all members' local definitions."""
        return sorted(set().union(*(m.locals() for m in self._members)))

    def body(self) -> List[str]:
        """Every member's override body (possibly ""), in sorted order."""
        return sorted(m.body() for m in self._members)

    def has_non_default_args(self) -> bool:
        return any(not m.default_val_str for m in self._members)

    @property
    def non_defaults(self) -> "MemberSet":
        return MemberSet(m for m in self._members if not m.default_val_str)

    def __iter__(self) -> Iterator[Member]:
        return iter(self._members)


class GlobalElem:
    """Code generator for one global XSD element or type.

    Renders a simple-type class or enum, an abstract substitution-group alias
    module, or an ``@ome_dataclass`` module. The result can be formatted with
    isort/black and written to disk.
    """

    def __init__(self, elem: Union[XsdElement, XsdType]):
        assert elem.is_global()
        self.elem = elem

    @property
    def type(self) -> XsdType:
        return self.elem if self.is_type else self.elem.type

    @property
    def is_complex(self) -> bool:
        if hasattr(self.type, "is_complex"):
            return self.type.is_complex()
        return False

    @property
    def is_element(self) -> bool:
        return isinstance(self.elem, XsdElement)

    @property
    def is_type(self) -> bool:
        return isinstance(self.elem, XsdType)

    @property
    def is_enum(self) -> bool:
        is_enum = bool(self.elem.get_facet(qnames.XSD_ENUMERATION) is not None)
        if is_enum:
            if not len(self.elem.facets) == 1:
                raise NotImplementedError("Unexpected enum with multiple facets")
        return is_enum

    def _simple_class(self) -> List[str]:
        """Lines for an enum, or a subclass carrying facet constraints."""
        if self.is_enum:
            return make_enum(self.elem)

        lines = []
        if self.type.base_type.is_restriction():
            parent = self.type.base_type.local_name
        else:
            # it's a restriction of a builtin
            pytype = self.elem.base_type.python_type.__name__
            parent = f"Constrained{pytype.title()}"
            lines.extend([f"from pydantic.types import {parent}", ""])
        lines.append(f"class {self.elem.local_name}({parent}):")

        members = []
        for key, facet in self.elem.facets.items():
            members.extend([f"    {line}" for line in facet_converters[key](facet)])
        lines.extend(members if members else ["    pass"])
        if any("re.compile" in m for m in members):
            lines = ["import re", ""] + lines
        return lines

    def _abstract_class(self) -> List[str]:
        """Lines aliasing an abstract element and its substitution group.

        Emits ``<Name> = <its type>`` plus ``<Name>Type = Union[...]`` over
        every element whose substitution group is this element.
        """
        # FIXME: ? this might be a bit of an OME-schema-specific hack
        # this seems to be how abstract is used in the OME schema
        for e in self.elem.iter_components():
            if e != self.elem:
                raise NotImplementedError(
                    "Don't yet know how to handle abstract class with sub-components"
                )

        subs = [
            el
            for el in self.elem.schema.elements.values()
            if el.substitution_group == self.elem.name
        ]

        if not subs:
            raise NotImplementedError(
                "Don't know how to handle abstract class without substitutionGroups"
            )

        for el in subs:
            # Every substitute must directly extend the abstract element's
            # type. (The previous check, `not ext and base == type`, only
            # fired on an impossible combination and so never rejected
            # anything.)
            extends_abstract_type = (
                el.type.is_extension() and el.type.base_type == self.elem.type
            )
            if not extends_abstract_type:
                raise NotImplementedError(
                    "Expected all items in substitution group to extend "
                    f"the type {self.elem.type} of Abstract element {self.elem}"
                )

        sub_names = [el.local_name for el in subs]
        lines = ["from typing import Union"]
        lines.extend([local_import(n) for n in sub_names])
        lines += [local_import(self.elem.type.local_name)]
        lines += [f"{self.elem.local_name} = {self.elem.type.local_name}", ""]
        lines += [f"{self.elem.local_name}Type = Union[{', '.join(sub_names)}]"]
        return lines

    def lines(self) -> str:
        """Unformatted module source for this element."""
        # FIXME: Refactor to remove BinData special-case.
        if not self.is_complex and self.elem.local_name != "BinData":
            lines = self._simple_class()
        elif self.elem.abstract:
            lines = self._abstract_class()
        else:
            lines = make_dataclass(self.elem)
        return "\n".join(lines)

    def format(self) -> str:
        """Module source with imports sorted and black formatting applied."""
        return black_format(sort_imports(self.lines() + "\n"))

    def write(self, filename: str) -> None:
        """Write the formatted module to *filename* as UTF-8.

        Missing parent directories are created. A bare filename (empty
        dirname) is written to the current directory instead of crashing in
        ``os.makedirs("")``. UTF-8 is explicit because generated source may
        contain non-ASCII text from the schema, and Python source defaults
        to UTF-8 regardless of the platform locale.
        """
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(filename, "w", encoding="utf-8") as f:
            f.write(self.format())

    @property
    def fname(self) -> str:
        return f"{camel_to_snake(self.elem.local_name)}.py"


# Directory containing this script; default paths are resolved relative to it.
_this_dir = os.path.dirname(__file__)
# Default schema location. A local copy may be used instead, e.g.
# os.path.join(_this_dir, "ome_types", "ome-2016-06.xsd").
_url = "https://www.openmicroscopy.org/Schemas/OME/2016-06/ome.xsd"
# Default output package for the generated model modules.
_target = os.path.join(_this_dir, "ome_types", "model")


def convert_schema(
    url: Union[str, Path] = _url, target_dir: Union[str, Path] = _target
) -> None:
    """Generate the model package from an XML schema.

    *target_dir* is deleted and recreated. It then receives:
      * one module per complex global type and per global element (names in
        OVERRIDES are skipped);
      * ``simple_types.py`` containing all simple types;
      * an ``__init__.py`` that re-exports every generated class, defines
        ``__all__``, and records the field plural mapping collected in
        ``Member.plurals_registry``.

    All files are written as UTF-8, since generated source may contain
    non-ASCII text taken from the schema.

    Args:
        url: schema location passed to ``XMLSchema`` (a ``Path`` is
            converted to ``str`` first).
        target_dir: output package directory.
    """
    print("Inspecting XML schema ...")
    if isinstance(url, Path):
        url = str(url)
    schema = XMLSchema(url)
    print("Building dataclasses ...")
    shutil.rmtree(target_dir, ignore_errors=True)
    # Recreate the directory up front: simple_types.py and __init__.py are
    # written here directly, not via GlobalElem.write.
    os.makedirs(target_dir, exist_ok=True)
    init_imports: List[Tuple[str, str]] = []
    simples: List[GlobalElem] = []
    for elem in sorted(schema.types.values(), key=sort_types):
        if elem.local_name in OVERRIDES:
            continue
        converter = GlobalElem(elem)
        if not elem.is_complex():
            # Simple types are collected and emitted together below.
            simples.append(converter)
            continue
        init_imports.append((converter.fname, elem.local_name))
        converter.write(filename=os.path.join(target_dir, converter.fname))

    for elem in schema.elements.values():
        if elem.local_name in OVERRIDES:
            continue
        converter = GlobalElem(elem)
        init_imports.append((converter.fname, elem.local_name))
        converter.write(filename=os.path.join(target_dir, converter.fname))

    text = "\n".join(s.format() for s in simples)
    text = black_format(sort_imports(text))
    with open(os.path.join(target_dir, "simple_types.py"), "w", encoding="utf-8") as f:
        f.write(text)

    class_names = [classname for _, classname in init_imports]
    text = "".join(local_import(classname) + "\n" for classname in class_names)
    text = sort_imports(text)
    exported = ", ".join(sorted(repr(classname) for classname in class_names))
    text += f"\n\n__all__ = [{exported}]"
    # FIXME This could probably live somewhere else less visible to end-users.
    text += "\n\n_field_plurals = " + repr(Member.plurals_registry)
    text = black_format(text)
    with open(os.path.join(target_dir, "__init__.py"), "w", encoding="utf-8") as f:
        f.write(text)


if __name__ == "__main__":
    # Script entry point: regenerate the model package using the default
    # schema location (_url) and output directory (_target).
    convert_schema()
