# -*- coding: utf-8 -*-
# code generated by Prisma. DO NOT EDIT.
# pylint: disable=all
# pyright: reportUnusedImport=false
# fmt: off

# global imports for type checking
import sys
import datetime
from typing import (
    TYPE_CHECKING,
    Optional,
    Iterable,
    Iterator,
    Mapping,
    Tuple,
    Union,
    List,
    Dict,
    Type,
    Any,
    Set,
    overload,
    cast,
)

if sys.version_info >= (3, 8):
    from typing import TypedDict, Literal
else:
    from typing_extensions import TypedDict, Literal

# -- template builder.py.jinja --

# TODO: the QueryBuilder should validate and add type information context.
#       currently we just naively iterate through arguments and encode them
#       using standard json when we don't have any special casing for it.
#       this makes it more difficult to add support for non-standard types
#       such as the `Json` type.
# TODO: optimise for performance (switch to c / cython?)
# TODO: pass context around differently; relying on the builder instance is
#       not ideal, context should be local to each node


import json
import logging
import inspect
from textwrap import indent
from datetime import timezone
from abc import abstractmethod, ABC
from functools import singledispatch

from ._types import Serializable
from .errors import UnknownModelError, UnknownRelationalFieldError


# Module logger; QueryBuilder.build_query logs the generated query at DEBUG level.
log: logging.Logger = logging.getLogger(__name__)

# A child of a node is either another node (rendered recursively) or a
# pre-rendered string that is emitted (indented) as-is.
ChildType = Union['AbstractNode', str]


# Argument keys that QueryBuilder._transform_aliases renames before the
# query is rendered, e.g. the Python-friendly `order_by` becomes `orderBy`.
GLOBAL_ALIASES = {
    'startswith': 'startsWith',
    'endswith': 'endsWith',
    'order_by': 'orderBy',
    'not_in': 'notIn',
    'is_not': 'isNot',
    'NOT': 'not',
    'IN': 'in',
}

# Scalar fields selected for each model when no explicit root selection
# is given (returned by QueryBuilder.get_default_fields).
DEFAULT_FIELDS_MAPPING = {
    'Post': [
        'id',
        'created_at',
        'updated_at',
        'title',
        'published',
        'views',
        'desc',
        'author_id',
    ],
    'User': [
        'id',
        'name',
    ],
    'Category': [
        'id',
        'name',
    ],
    'Profile': [
        'id',
        'user_id',
        'bio',
    ],
}

# For each model, maps a relational field name to the name of the model it
# points to (looked up by QueryBuilder.get_relational_model).
RELATIONAL_FIELD_MAPPINGS = {
    'Post': {
        'author': 'User',
        'categories': 'Category',
    },
    'User': {
        'posts': 'Post',
        'profile': 'Profile',
    },
    'Category': {
        'posts': 'Post',
    },
    'Profile': {
        'user': 'User',
    },
}


class QueryBuilder:
    """Builds the QueryEngine payload for a single prisma client method call.

    The payload is a JSON document containing the operation name and the
    rendered GraphQL query, see `build()`.
    """
    # prisma method
    method: str

    # GraphQL operation
    operation: str

    # prisma model
    model: Optional[str]

    # mapping of relational fields to include in the result
    # NOTE: this should be a recursive type
    # IncludeType = Union[bool, Dict[str, 'IncludeType']]
    include: Optional[Dict[str, Any]]

    # arguments to pass to the query
    arguments: Dict[str, Any]

    # list of fields to select
    root_selection: Optional[List[str]]

    def __init__(
        self,
        *,
        method: str,
        operation: str,
        arguments: Dict[str, Any],
        model: Optional[str] = None,
        root_selection: Optional[List[str]] = None
    ) -> None:
        self.model = model
        self.method = method
        self.operation = operation
        self.root_selection = root_selection
        # `_transform_aliases` returns a new dict, so popping `include` below
        # does not mutate the caller's arguments
        self.arguments = args = self._transform_aliases(arguments)
        self.include = args.pop('include', None)

    def build(self) -> str:
        """Build the payload that should be sent to the QueryEngine"""
        data = {
            'variables': {},
            'operation_name': self.operation,
            'query': self.build_query(),
        }
        return dumps(data)

    def build_query(self) -> str:
        """Build the GraphQL query

        Example query:

        query {
          result: findUniqueUser
          (
            where: {
              id: "ckq23ky3003510r8zll5m2hma"
            }
          )
          {
            id
            name
            profile {
              id
              user_id
              bio
            }
          }
        }
        """
        query = self._create_root_node().render()
        log.debug('Generated query: \n%s', query)
        return query

    def _create_root_node(self) -> 'RootNode':
        root = RootNode(builder=self)
        root.add(ResultNode.create(self))
        root.add(
            Selection.create(
                self,
                model=self.model,
                include=self.include,
                root_selection=self.root_selection,
            )
        )
        return root

    def get_default_fields(self, model: str) -> List[str]:
        """Returns a list of all the scalar fields of a model

        Raises UnknownModelError if the current model cannot be found.
        """
        try:
            # copy so callers can extend the result without corrupting the mapping
            return DEFAULT_FIELDS_MAPPING[model].copy()
        except KeyError as exc:
            raise UnknownModelError(model) from exc

    def get_relational_model(self, current_model: str, field: str) -> str:
        """Returns the model that the field is related to.

        Raises UnknownModelError if the current model is invalid.
        Raises UnknownRelationalFieldError if the field does not exist.
        """
        try:
            mappings = RELATIONAL_FIELD_MAPPINGS[current_model]
        except KeyError as exc:
            raise UnknownModelError(current_model) from exc

        try:
            return mappings[field]
        except KeyError as exc:
            raise UnknownRelationalFieldError(model=current_model, field=field) from exc

    def _transform_aliases(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Transform dict keys to match global aliases

        e.g. order_by -> orderBy

        Aliases are applied recursively to nested dictionaries. That includes
        dictionaries contained in list / tuple values, such as
        ``{'OR': [{'title': {'startswith': 'a'}}]}``. The given mapping is
        not mutated; a new dictionary is returned.
        """
        transformed = dict()
        for key, value in arguments.items():
            alias = GLOBAL_ALIASES.get(key, key)
            if isinstance(value, dict):
                transformed[alias] = self._transform_aliases(arguments=value)
            elif isinstance(value, (list, tuple)):
                # list filters such as AND / OR / NOT hold filter objects whose
                # keys must be aliased too; non-dict items are left untouched
                items = [
                    self._transform_aliases(arguments=item) if isinstance(item, dict) else item
                    for item in value
                ]
                transformed[alias] = items if isinstance(value, list) else tuple(items)
            else:
                transformed[alias] = value
        return transformed


class AbstractNode(ABC):
    """Interface for everything that can be rendered into the query string"""

    @abstractmethod
    def render(self) -> Optional[str]:
        """Render the node to a string

        None is returned if the node should not be rendered.
        """
        ...

    def should_render(self) -> bool:
        """If False, rendering of the node is skipped

        Defaults to True. Some nodes override this so that they are only
        actually rendered if they have any children.
        """
        return True


class Node(AbstractNode):
    """Base node handling rendering of child nodes"""
    joiner: str
    indent: str
    builder: QueryBuilder
    children: List[ChildType]

    def __init__(
        self,
        builder: QueryBuilder,
        *,
        joiner: str = '\n',
        indent: str = '  ',
        children: Optional[List[ChildType]] = None
    ) -> None:
        self.builder = builder
        # string placed between the entered string, each rendered child
        # and the departed string
        self.joiner = joiner
        # prefix added to every line of each rendered child
        self.indent = indent
        self.children = children if children is not None else []

    def enter(self) -> Optional[str]:
        """Get the string used to enter the node.

        This string will be rendered *before* the children.
        """
        return None

    def depart(self) -> Optional[str]:
        """Get the string used to depart the node.

        This string will be rendered *after* the children.
        """
        return None

    def render(self) -> Optional[str]:
        """Render the node and its children to a string.

        Rendering a node involves 4 steps:

        1. Entering the node
        2. Rendering its children
        3. Departing the node
        4. Joining the previous steps together into a single string

        Children that render to None or to an empty string are omitted.
        Every other child is indented by ``self.indent``. None is returned
        if ``should_render()`` is False.
        """
        if not self.should_render():
            return None

        strings: List[str] = []
        entered = self.enter()
        if entered is not None:
            strings.append(entered)

        for child in self.children:
            # plain strings are pre-rendered content, nodes render themselves
            content: Optional[str] = child if isinstance(child, str) else child.render()

            if content:
                # NOTE: `indent` is the module-level textwrap.indent function;
                # the indentation prefix is the `self.indent` attribute
                strings.append(indent(content, self.indent))

        departed = self.depart()
        if departed is not None:
            strings.append(departed)

        return self.joiner.join(strings)

    def add(self, child: ChildType) -> None:
        """Add a child"""
        self.children.append(child)

    def create_children(self) -> List[ChildType]:
        """Create the node's children

        If children are passed to the constructor, the children
        returned from this method are used to extend the already
        set children.
        """
        return []

    @classmethod
    def create(cls, builder: Optional[QueryBuilder] = None, **kwargs: Any) -> 'Node':
        """Create the node and its children

        Keyword arguments are forwarded to the constructor. `builder` is
        used unless a `builder` keyword is also given. The children are
        created only after __init__ has run. This lets subclasses that set
        extra attributes in __init__ use them in create_children().
        """
        kwargs.setdefault('builder', builder)
        node = cls(**kwargs)
        node.children.extend(node.create_children())
        return node


class RootNode(Node):
    """Outermost node, wrapping everything in the GraphQL operation block.

    Rendered node examples:

    query {
        <children>
    }

    or

    mutation {
        <children>
    }
    """

    def enter(self) -> str:
        operation = self.builder.operation
        return '{} {{'.format(operation)

    def depart(self) -> str:
        closing_brace = '}'
        return closing_brace

    def render(self) -> str:
        """Render the complete query, raising RuntimeError if it came out empty"""
        rendered = super().render()
        if rendered:
            return rendered

        # Unreachable in practice: this node always renders at least its
        # enter/depart strings. The check only narrows the Optional[str]
        # return type of Node.render to str.
        raise RuntimeError('Could not generate query.')  # pragma: no cover


class ResultNode(Node):
    """Node naming the prisma method being called, followed by its arguments.

    Rendered node examples:

    result: findUniqueUser
        <children>

    or

    result: executeRaw
        <children>
    """
    def __init__(self, indent: str = '', **kwargs: Any) -> None:
        # children (the arguments) are not indented by default
        super().__init__(indent=indent, **kwargs)

    def enter(self) -> str:
        builder = self.builder
        # the model name, when there is one, is appended directly to the method
        target = builder.method if builder.model is None else f'{builder.method}{builder.model}'
        return f'result: {target}'

    def depart(self) -> Optional[str]:
        # nothing is rendered after the children
        return None

    def create_children(self) -> List[ChildType]:
        builder = self.builder
        arguments_node = Arguments.create(builder, arguments=builder.arguments)
        return [arguments_node]


class Arguments(Node):
    """Rendered node example:

    (
        key1: "1"
        key2: "[\"John\",\"123\"]"
        key3: true
        key4: {
            data: true
        }
    )
    """
    arguments: Dict[str, Any]

    def __init__(self, arguments: Dict[str, Any], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.arguments = arguments

    def should_render(self) -> bool:
        # with no (non-None) arguments nothing is rendered, not even `()`
        return bool(self.children)

    def enter(self) -> str:
        return '('

    def depart(self) -> str:
        return ')'

    def create_children(self, arguments: Optional[Dict[str, Any]] = None) -> List[ChildType]:
        """Create one child per argument.

        `arguments` defaults to the mapping given to the constructor. Passing
        a mapping explicitly creates the children for that mapping instead.

        Arguments whose value is None are skipped. Dict values are rendered
        as nested objects and list / tuple / set values as lists. Everything
        else is JSON-encoded.
        """
        if arguments is None:
            arguments = self.arguments

        children: List[ChildType] = []

        for arg, value in arguments.items():
            if value is None:
                # ignore None values for convenience
                continue

            if isinstance(value, dict):
                children.append(
                    Key(arg, node=Data.create(self.builder, data=value))
                )
            elif isinstance(value, (list, tuple, set)):
                # NOTE: we have a special case for execute_raw and query_raw
                # here as prisma expects parameters to be passed as a json string
                # value like "[\"John\",\"123\"]", and we encode twice to ensure
                # that only the inner quotes are escaped
                if self.builder.method in {'queryRaw', 'executeRaw'}:
                    children.append(f'{arg}: {dumps(dumps(value))}')
                else:
                    children.append(Key(arg, node=ListNode.create(self.builder, data=value)))
            else:
                children.append(f'{arg}: {dumps(value)}')

        return children


class Data(Node):
    """Renders a mapping as a GraphQL input object.

    Rendered node example:

    {
      key1: "a"
      key2: 3
      key3: {
        nested: true
      }
    }
    """
    data: Mapping[str, Any]

    def __init__(self, data: Mapping[str, Any], **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.data = data

    def enter(self) -> str:
        return '{'

    def depart(self) -> str:
        return '}'

    def create_children(self) -> List[ChildType]:
        return [self._create_child(key, value) for key, value in self.data.items()]

    def _create_child(self, key: str, value: Any) -> ChildType:
        """Create the child for a single key / value pair"""
        # dicts become nested objects
        if isinstance(value, dict):
            return Key(key, node=Data.create(self.builder, data=value))

        # any sequence-like container becomes a list
        if isinstance(value, (list, tuple, set)):
            return Key(key, node=ListNode.create(self.builder, data=value))

        # everything else is a JSON scalar
        return f'{key}: {dumps(value)}'


class ListNode(Node):
    """Renders an iterable as a GraphQL list.

    Dict items are rendered as nested objects; every other item is
    JSON-encoded. The joiner (',\\n' by default) is placed between every
    rendered part, including after the opening and before the closing
    bracket. The output therefore contains a leading and a trailing comma,
    e.g. for ['a', 'b']:

    [,
      "a",
      "b",
    ]

    GraphQL treats commas as insignificant, so this is still valid.
    """
    data: Iterable[Any]

    def __init__(self, data: Iterable[Any], joiner: str = ',\n', **kwargs: Any) -> None:
        super().__init__(joiner=joiner, **kwargs)
        self.data = data

    def enter(self) -> str:
        return '['

    def depart(self) -> str:
        return ']'

    def create_children(self) -> List[ChildType]:
        builder = self.builder
        return [
            Data.create(builder, data=element) if isinstance(element, dict) else dumps(element)
            for element in self.data
        ]


class Selection(Node):
    """Represents field selections

    (field lists below are abbreviated)

    Example no include:

    {
        id
        name
    }

    Example include={'posts': True}

    {
        id
        name
        posts {
            id
            title
        }
    }

    Example include={'posts': {'where': {'title': {'contains': 'Test'}}}}

    {
        id
        name
        posts(
            where: {
                title: {
                    contains: "Test"
                }
            }
        )
        {
            id
            title
        }
    }
    """
    model: Optional[str]
    include: Optional[Dict[str, Any]]
    root_selection: Optional[List[str]]

    def __init__(
        self,
        model: Optional[str] = None,
        include: Optional[Dict[str, Any]] = None,
        root_selection: Optional[List[str]] = None,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.include = include
        self.root_selection = root_selection

    def should_render(self) -> bool:
        # nothing is rendered, not even `{}`, when there is nothing to select
        return bool(self.children)

    def enter(self) -> str:
        return '{'

    def depart(self) -> str:
        return '}'

    def create_children(self) -> List[ChildType]:
        """Create the selected scalar fields followed by any included relations.

        Raises ValueError if `include` is given without a model.
        Raises TypeError if an include value is neither a bool nor a dict.
        Raises UnknownModelError / UnknownRelationalFieldError (from the
        builder lookups) for an unknown model or included relational field.
        """
        model = self.model
        include = self.include
        builder = self.builder
        children: List[ChildType] = []

        # root_selection, if present overrides the default fields
        # for a model as it is used by methods such as count()
        # that do not support returning model fields
        root_selection = self.root_selection
        if root_selection is not None:
            children.extend(root_selection)
        elif model is not None:
            children.extend(builder.get_default_fields(model))

        if include is None:
            return children

        if model is None:
            raise ValueError('Cannot include fields when model is None.')

        for key, value in include.items():
            if value is False:
                # explicitly excluded relation, nothing to render (and the
                # field is intentionally not looked up)
                continue

            if value is True:
                # e.g. posts { post_fields }
                children.append(
                    Key(
                        key,
                        sep=' ',
                        node=Selection.create(
                            builder,
                            include=None,
                            model=builder.get_relational_model(
                                current_model=model, field=key
                            ),
                        ),
                    )
                )
            elif isinstance(value, dict):
                # e.g. given {'posts': {where': {'published': True}}} return
                # posts( where: { published: true }) { post_fields }
                # shallow copy so popping `include` does not mutate the caller's mapping
                args = value.copy()
                nested_include = args.pop('include', None)
                children.extend(
                    [
                        # if `args` is empty the Arguments node renders nothing
                        # and this reduces to just the field name
                        Key(
                            key,
                            sep='',
                            node=Arguments.create(
                                builder, arguments=args
                            ),
                        ),
                        Selection.create(
                            builder,
                            include=nested_include,
                            model=builder.get_relational_model(
                                current_model=model, field=key
                            ),
                        ),
                    ]
                )
            else:
                raise TypeError(
                    f'Expected `bool` or `dict` include value but got {type(value)} instead.'
                )

        return children


class Key(AbstractNode):
    """Renders a child node prefixed with a key and a separator.

    If the child renders to nothing, only the key and separator are returned.
    For example, Selection uses sep='' so that a relation whose argument list
    is skipped renders as just its name.
    """
    key: str
    sep: str
    node: Node

    def __init__(self, key: str, node: Node, sep: str = ': ') -> None:
        self.key = key
        self.node = node
        self.sep = sep

    def render(self) -> str:
        rendered = self.node.render()
        prefix = f'{self.key}{self.sep}'
        return f'{prefix}{rendered}' if rendered else prefix


@singledispatch
def serializer(obj: Any) -> Serializable:
    """Fallback JSON encoder hook, dispatched on the type of `obj`

    Support for specific types is added with ``@serializer.register``. This
    base implementation handles every unregistered type by raising a
    TypeError that names the offending type. When `obj` is itself a class,
    the class is named instead.
    """
    offending_type = obj if inspect.isclass(obj) else type(obj)
    raise TypeError(f'Type {offending_type} not serializable')


@serializer.register(datetime.datetime)
def serialize_datetime(dt: datetime.datetime) -> str:
    """Serialize a datetime as an ISO8601 string in UTC.

    Naive datetimes are assumed to already be in UTC and are only tagged
    with the UTC timezone. Aware datetimes in another timezone are converted
    to UTC first.
    """
    tz = dt.tzinfo
    if tz is None:
        normalised = dt.replace(tzinfo=timezone.utc)
    elif tz != timezone.utc:
        normalised = dt.astimezone(timezone.utc)
    else:
        normalised = dt

    return normalised.isoformat()


def dumps(obj: Any, **kwargs: Any) -> str:
    """Encode `obj` as a JSON string.

    Unless the caller overrides them, unsupported types are delegated to
    `serializer` and non-ASCII characters are written as-is rather than
    being escaped. Any other keyword arguments are passed to json.dumps.
    """
    options: Dict[str, Any] = {'default': serializer, 'ensure_ascii': False}
    options.update(kwargs)
    return json.dumps(obj, **options)

# black does not respect the fmt: off comment without this
# fmt: on
