# -*- coding: utf-8 -*-
# code generated by Prisma. DO NOT EDIT.
# pylint: disable=all
# pyright: reportUnusedImport=false
# fmt: off

# global imports for type checking
import sys
import datetime
from typing import (
    TYPE_CHECKING,
    Optional,
    Iterable,
    Iterator,
    Mapping,
    Tuple,
    Union,
    List,
    Dict,
    Type,
    Any,
    Set,
    overload,
    cast,
)

if sys.version_info >= (3, 8):
    from typing import TypedDict, Literal
else:
    from typing_extensions import TypedDict, Literal

# -- template engine/query.py.jinja --

import os
import time
import atexit
import signal
import asyncio
import logging
import subprocess
from pathlib import Path

from . import utils, errors
from ..http import HTTP
from ..utils import DEBUG
from .._types import Method
from ..binaries import platform
from ..utils import time_since, _env_bool


# names exported by `from <this module> import *`
__all__ = ('QueryEngine',)

# module-level logger, named after this module's import path
log: logging.Logger = logging.getLogger(__name__)


class QueryEngine:
    """Run a query engine binary as a subprocess and talk to it over HTTP.

    The engine is started by `connect()`, which spawns the binary on a free
    local port and polls its `/status` endpoint until it answers. Requests
    are then sent with `request()`. `disconnect()` terminates the subprocess
    and `stop()` additionally schedules the HTTP session to be closed.
    """

    dml: str
    session: HTTP

    def __init__(self, *, dml: str, log_queries: bool = False):
        """Store the configuration and register `stop()` to run at interpreter exit.

        Args:
            dml: passed to the engine process through the `PRISMA_DML`
                environment variable.
            log_queries: when true, `LOG_QUERIES=y` is set in the engine's
                environment.
        """
        self.dml = dml
        self._log_queries = log_queries

        # plain attributes are assigned before the HTTP session is created so
        # that `stop()` (also invoked from `__del__`) never operates on a
        # half-initialised instance if `HTTP()` raises
        self.url = None  # type: Optional[str]
        self.process = None  # type: Optional[subprocess.Popen[bytes]]
        self.file = None  # type: Optional[Path]
        self.session = HTTP()

        # ensure the query engine process is terminated when we are
        # NOTE: atexit keeps a reference to this bound method, and therefore
        # to the instance, until the interpreter exits
        atexit.register(self.stop)

    def __del__(self) -> None:
        """Stop the engine when the instance is garbage collected."""
        self.stop()

    def stop(self) -> None:
        """Terminate the engine process and schedule closing the HTTP session.

        The session is closed through a task created on the current thread's
        event loop; if there is no such loop, or it is already closed, only
        the process is terminated.
        """
        self.disconnect()
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # no event loop in the current thread, we cannot cleanup
            return

        if not loop.is_closed():
            loop.create_task(self.close_session())

    def disconnect(self) -> None:
        """Terminate the engine process, if any, and forget its address.

        The process is killed on Windows and sent SIGINT elsewhere; this
        method then blocks until the process has exited. Calling it while
        not connected is a no-op apart from logging.
        """
        log.debug('Disconnecting query engine...')

        process = self.process
        if process is not None:
            if platform.name() == 'windows':
                process.kill()
            else:
                process.send_signal(signal.SIGINT)

            process.wait()
            self.process = None

        # nothing is listening on the old address any more; clearing it makes
        # `request()` raise NotConnectedError instead of hitting a dead port
        self.url = None

        log.debug('Disconnected query engine')

    async def close_session(self) -> None:
        """Close the HTTP session unless it is missing or already closed."""
        # `session` may be absent if `__init__` failed while creating it
        session = getattr(self, 'session', None)
        if session and not session.closed:
            await session.close()

    async def connect(self, timeout: int = 10) -> None:
        """Start the query engine and wait until it is reachable.

        Args:
            timeout: forwarded to `spawn()`; bounds how long the engine's
                status endpoint is polled.

        Raises:
            errors.AlreadyConnectedError: if an engine process is already
                attached to this instance.
            Exception: anything raised by `spawn()` is re-raised after the
                partially started engine has been disconnected.
        """
        log.debug('Connecting to query engine')
        if self.process is not None:
            raise errors.AlreadyConnectedError('Already connected to the query engine')

        start = time.monotonic()
        # presumably resolves the path of the engine binary; see utils.ensure
        self.file = file = utils.ensure()

        try:
            await self.spawn(file, timeout=timeout)
        except Exception:
            # do not leave a stray process (or stale url) behind on failure
            self.disconnect()
            raise

        log.debug('Connecting to query engine took %s', time_since(start))

    async def spawn(self, file: Path, timeout: int = 10) -> None:
        """Launch the engine binary and poll `/status` until it responds.

        Args:
            file: path of the engine executable to run.
            timeout: the status endpoint is polled `int(timeout / 0.1)`
                times, sleeping 0.1 seconds after every failed attempt.

        Raises:
            errors.EngineConnectionError: if the engine process exits before
                answering, or if no attempt succeeded.
        """
        port = utils.get_open_port()
        log.debug('Running query engine on port %i', port)

        self.url = f'http://localhost:{port}'

        env = os.environ.copy()
        env.update(
            PRISMA_DML=self.dml,
            RUST_LOG='error',
            RUST_LOG_FORMAT='json',
        )

        if DEBUG:
            env.update(RUST_LOG='info')

        # TODO: remove the noise from these query logs
        if self._log_queries:
            env.update(LOG_QUERIES='y')

        args: List[str] = [str(file.absolute()), '-p', str(port), '--enable-raw-queries']
        if _env_bool('__PRISMA_PY_PLAYGROUND'):
            env.update(RUST_LOG='info')
            args.append('--enable-playground')

        log.debug('Starting query engine...')
        # keep a local reference: `self.process` may be reset by a concurrent
        # `disconnect()` while this coroutine is suspended
        process = self.process = subprocess.Popen(
            args,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )

        poll_interval = 0.1
        last_exc = None  # type: Optional[Exception]
        for _ in range(int(timeout / poll_interval)):
            try:
                data = await self.request('GET', '/status')
            except Exception as exc:  # pylint: disable=broad-except
                last_exc = exc

                # if the engine has already exited, retrying can never
                # succeed; fail immediately instead of waiting out the timeout
                returncode = process.poll()
                if returncode is not None:
                    raise errors.EngineConnectionError(
                        'Could not connect to the query engine; '
                        f'the process exited with code {returncode}'
                    ) from exc

                log.debug(
                    'Could not connect to query engine due to %s; retrying...',
                    type(exc).__name__,
                )
                await asyncio.sleep(poll_interval)

                continue

            if data.get('Errors') is not None:
                log.debug('Could not connect due to gql errors; retrying...')
                await asyncio.sleep(poll_interval)

                continue

            break
        else:
            # every attempt failed (or none was made because timeout < 0.1)
            raise errors.EngineConnectionError(
                'Could not connect to the query engine'
            ) from last_exc

    async def request(self, method: Method, path: str, *, data: Any = None) -> Any:
        """Send a JSON request to the engine and return the decoded response.

        Args:
            method: HTTP method to use.
            path: appended verbatim to the engine's base url.
            data: optional request body, forwarded to the HTTP session as-is.

        Returns:
            The decoded JSON body of a 2xx response, or whatever
            `utils.handle_response_errors` returns when that body carries a
            non-empty `errors` entry.

        Raises:
            errors.NotConnectedError: if no engine url is set.
            errors.UnprocessableEntityError: on a 422 response.
            errors.EngineRequestError: on any other non-2xx response.
        """
        if self.url is None:
            raise errors.NotConnectedError('Not connected to the query engine')

        # annotated explicitly: the dict holds the headers mapping and,
        # optionally, an arbitrary request body
        kwargs: Dict[str, Any] = {
            'headers': {
                'Content-Type': 'application/json',
                'Accept': 'application/json',
            }
        }

        if data is not None:
            kwargs['data'] = data

        url = self.url + path
        log.debug('Sending %s request to %s with data: %s', method, url, data)

        resp = await self.session.request(method, url, **kwargs)

        if 200 <= resp.status < 300:
            response = await resp.json()
            log.debug('%s %s returned %s', method, url, response)

            errors_data = response.get('errors')
            if errors_data:
                return utils.handle_response_errors(resp, errors_data)

            return response

        if resp.status == 422:
            raise errors.UnprocessableEntityError(resp)

        # TODO: handle errors better
        raise errors.EngineRequestError(resp, await resp.text())


# black does not respect the fmt: off comment without this
# fmt: on
