from io import TextIOWrapper
from typing import Iterator, Optional

import orjson
import requests
import yaml

from netlas.exception import APIError
from netlas.helpers import check_status_code


class Netlas:
    """Client for the Netlas HTTP API.

    Every public method issues a GET request against ``apibase`` with the
    ``X-Api-Key`` header set, and either returns the decoded JSON body or
    (for :meth:`download`) yields the raw response bytes in chunks.
    """

    # Maps the public ``datatype`` argument to the URL path segment used by the
    # search/count/download endpoints. Unknown values fall back to "responses".
    _DATATYPE_PATHS = {
        "response": "responses",
        "cert": "certs",
        "domain": "domains",
    }

    def __init__(self,
                 api_key: str = "",
                 apibase: str = "https://app.netlas.io",
                 debug: bool = False) -> None:
        """Netlas class constructor

        :param api_key: Personal API key, defaults to ""
        :type api_key: str
        :param apibase: Netlas API server address, defaults to "https://app.netlas.io"
        :type apibase: str, optional
        :param debug: Debug flag, defaults to False
        :type debug: bool, optional
        :raises APIError: ``api_key`` is empty
        """
        self.api_key: str = api_key
        self.apibase: str = apibase.rstrip("/")
        self.debug: bool = debug
        # NOTE(review): TLS certificate verification is switched off for every
        # server other than the public endpoint (presumably to support
        # self-hosted instances with self-signed certificates). On such hosts
        # the API key is sent over an unverified connection.
        self.verify_ssl: bool = self.apibase == "https://app.netlas.io"
        if not self.api_key:
            raise APIError({"error": "API key is empty"})
        self.headers = {
            'Content-Type': 'application/json',
            'X-Api-Key': self.api_key
        }

    @classmethod
    def _datatype_path(cls, datatype: str) -> str:
        """Return the endpoint path segment for *datatype* ("responses" if unknown)."""
        return cls._DATATYPE_PATHS.get(datatype, "responses")

    def _request(self,
                 endpoint: str = "/api/",
                 params: Optional[dict] = None) -> dict:
        """Private requests wrapper.
        Sends a GET request to a Netlas API endpoint and returns the decoded JSON body.

        :param endpoint: API endpoint, defaults to "/api/"
        :type endpoint: str
        :param params: GET parameters for request, defaults to None (no parameters)
        :type params: dict, optional
        :raises APIError: API key is empty
        :raises APIError: Failed to parse JSON response
        :raises APIError: HTTP error reported by ``check_status_code``
        :return: parsed JSON response
        :rtype: dict
        """
        ret: dict = {}
        if not self.api_key:
            ret["error"] = "API key is empty"
            raise APIError(ret["error"])

        r = requests.get(f"{self.apibase}{endpoint}",
                         params=params,
                         headers=self.headers,
                         verify=self.verify_ssl)
        # Check the status before decoding (as _stream_request does) so that a
        # non-JSON error page is reported by its HTTP status rather than as a
        # JSON parse failure.
        check_status_code(request=r, debug=self.debug, ret=ret)
        try:
            return orjson.loads(r.text)
        except orjson.JSONDecodeError as err:
            ret["error"] = "Failed to parse response data to JSON"
            if self.debug:
                # NOTE(review): these debug fields are collected, but only
                # ret["error"] is passed to APIError.
                ret["error_description"] = r.reason
                ret["error_data"] = r.text
            raise APIError(ret["error"]) from err

    def _stream_request(self,
                        endpoint: str = "/api/",
                        params: Optional[dict] = None) -> Iterator[bytes]:
        """Private stream requests wrapper.
        Sends a streaming GET request to a Netlas API endpoint and yields the body in chunks.

        :param endpoint: API endpoint, defaults to "/api/"
        :type endpoint: str
        :param params: GET parameters for request, defaults to None (no parameters)
        :type params: dict, optional
        :raises APIError: API key is empty
        :raises APIError: Other HTTP error
        :return: Iterator of raw bytes from response
        :rtype: Iterator[bytes]
        """
        ret: dict = {}
        if not self.api_key:
            ret["error"] = "API key is empty"
            raise APIError(ret['error'])
        with requests.get(f"{self.apibase}{endpoint}",
                          params=params,
                          headers=self.headers,
                          verify=self.verify_ssl,
                          stream=True) as r:
            # The handler lives inside the ``with`` so the response is still
            # open when the debug branch reads ``r.text``.
            try:
                check_status_code(request=r, debug=self.debug, ret=ret)
                for chunk in r.iter_content(chunk_size=2048):
                    # skip keep-alive chunks
                    if chunk:
                        yield chunk
            except requests.HTTPError as err:
                ret["error"] = f"{r.status_code}: {r.reason}"
                if self.debug:
                    ret["error_description"] = r.reason
                    ret["error_data"] = r.text
                raise APIError(ret['error']) from err

    def query(self,
              query: str,
              datatype: str = "response",
              indices: str = "") -> dict:
        """Send search query to Netlas API

        :param query: Search query string
        :type query: str
        :param datatype: Data type (choices: response, cert, domain), defaults to "response"
        :type datatype: str, optional
        :param indices: Comma-separated IDs of selected data indices (can be retrieved by `indices` method), defaults to ""
        :type indices: str, optional
        :return: search query result
        :rtype: dict
        """
        endpoint = f"/api/{self._datatype_path(datatype)}/"
        return self._request(
            endpoint=endpoint,
            params={
                "q": query,
                "indices": indices
            },
        )

    def count(self,
              query: str,
              datatype: str = "response",
              indices: str = "") -> dict:
        """Calculate total count of query string results

        :param query: Search query string
        :type query: str
        :param datatype: Data type (choices: response, cert, domain), defaults to "response"
        :type datatype: str, optional
        :param indices: Comma-separated IDs of selected data indices (can be retrieved by `indices` method), defaults to ""
        :type indices: str, optional
        :return: JSON object with total count of query string results
        :rtype: dict
        """
        endpoint = f"/api/{self._datatype_path(datatype)}_count/"
        return self._request(endpoint=endpoint,
                             params={
                                 "q": query,
                                 "indices": indices
                             })

    def stat(self, query: str, indices: str = "") -> dict:
        """Get statistics of responses query string results

        :param query: Search query string
        :type query: str
        :param indices: Comma-separated IDs of selected data indices (can be retrieved by `indices` method), defaults to ""
        :type indices: str, optional
        :return: JSON object with statistics of responses query string results
        :rtype: dict
        """
        return self._request(
            endpoint="/api/responses_stat/",
            params={
                "q": query,
                "indices": indices
            },
        )

    def profile(self) -> dict:
        """Get user profile data

        :return: JSON object with user profile data
        :rtype: dict
        """
        return self._request(endpoint="/api/users/profile/")

    def host(self, host: str, hosttype: str = "ip", index: str = "") -> dict:
        """Get full information about host (ip or domain)

        :param host: IP or domain string
        :type host: str
        :param hosttype: `"ip"` or `"domain"` (any other value is treated as `"ip"`), defaults to "ip"
        :type hosttype: str, optional
        :param index: ID of the selected data index (can be retrieved by `indices` method), defaults to ""
        :type index: str, optional
        :return: JSON object with full information about host
        :rtype: dict
        """
        endpoint = "/api/domain/" if hosttype == "domain" else "/api/ip/"
        return self._request(
            endpoint=endpoint,
            params={
                "q": host,
                "index": index
            },
        )

    def download(self,
                 query: str,
                 datatype: str = "response",
                 size: int = 10,
                 indices: str = "") -> Iterator[bytes]:
        """Download data from Netlas

        :param query: Search query string
        :type query: str
        :param datatype: Data type (choices: response, cert, domain), defaults to "response"
        :type datatype: str, optional
        :param size: Download documents count, defaults to 10
        :type size: int, optional
        :param indices: Comma-separated IDs of selected data indices (can be retrieved by `indices` method), defaults to ""
        :type indices: str, optional
        :return: Iterator of raw data
        :rtype: Iterator[bytes]
        """
        endpoint = f"/api/{self._datatype_path(datatype)}/download/"
        yield from self._stream_request(
            endpoint=endpoint,
            params={
                "q": query,
                "size": size,
                "indices": indices
            },
        )

    def indices(self) -> list:
        """Get available data indices

        :return: List of available indices
        :rtype: list
        """
        return self._request(endpoint="/api/indices/")