from collections.abc import Mapping, MutableMapping
from collections import OrderedDict
from urllib.parse import urlsplit
from .model import Response
from .exceptions import *
import socket
import brotli
import socks
import zlib
import json
import gzip
import ssl


# CaseInsensitiveDict further down appears to be adapted from this requests source:
# https://github.com/psf/requests/blob/f6d43b03fbb9a1e75ed63a9aa15738a8fce99b50/requests/structures.py#L15

# Proxy URL scheme -> PySocks proxy type. By convention "socks4a" and
# "socks5h" mean "let the proxy resolve the target hostname"; they speak the
# same protocol as "socks4" / "socks5". "socks" is kept as a SOCKS4 alias.
scheme_to_proxy_type = {
    "http": socks.HTTP,
    "https": socks.HTTP,
    "socks": socks.SOCKS4,
    "socks4": socks.SOCKS4,
    "socks4a": socks.SOCKS4,
    "socks5": socks.SOCKS5,
    "socks5h": socks.SOCKS5,
}

# Default port for each supported target URL scheme. The keys are also the
# set of target schemes that Session accepts.
scheme_to_port = {"http": 80, "https": 443}


class CaseInsensitiveDict(MutableMapping):
    """Mapping with case-insensitive string keys that remembers key casing.

    Lookups, membership tests, assignment and deletion ignore the case of
    the key. Iteration (``iter(instance)``, ``keys()``, ``items()``) yields
    each key spelled exactly as it was most recently assigned::

        cid = CaseInsensitiveDict()
        cid['Accept'] = 'application/json'
        cid['aCCEPT'] == 'application/json'  # True
        list(cid) == ['Accept']  # True

    For example, ``headers['content-encoding']`` returns the value of a
    ``'Content-Encoding'`` header however the name was originally stored.

    Besides the full ``MutableMapping`` API it offers ``copy()`` and
    ``lower_items()``. Keys must be strings. If the constructor, ``update``
    or an equality comparison receives several keys that are equal after
    ``.lower()``, the behavior is undefined.
    """

    def __init__(self, data=None, **kwargs):
        # Maps lowercased key -> (key as last written, value), in insertion order.
        self._store = OrderedDict()
        self.update({} if data is None else data, **kwargs)

    def __setitem__(self, key, value):
        folded = key.lower()
        self._store[folded] = (key, value)

    def __getitem__(self, key):
        _, value = self._store[key.lower()]
        return value

    def __delitem__(self, key):
        del self._store[key.lower()]

    def __iter__(self):
        for spelled_key, _ in self._store.values():
            yield spelled_key

    def __len__(self):
        return len(self._store)

    def lower_items(self):
        """Return an iterator of ``(lowercased_key, value)`` pairs."""
        return ((folded, value) for folded, (_, value) in self._store.items())

    def __eq__(self, other):
        if not isinstance(other, Mapping):
            return NotImplemented
        # Normalise both sides to lowercase keys before comparing.
        theirs = CaseInsensitiveDict(other)
        return dict(self.lower_items()) == dict(theirs.lower_items())

    def copy(self):
        """Return a shallow copy that keeps the stored key casing."""
        return CaseInsensitiveDict(list(self._store.values()))

    def __repr__(self):
        return repr(dict(self.items()))


class ProxyError(Exception):
    """Raised when the configured client proxy could not be used."""


class Session:
    """Minimal HTTP/1.x client that keeps one open connection per host.

    Connections are pooled by ``(hostname, port)`` and reused by later
    requests to the same address. If a request fails on a reused connection
    (for instance because the server closed an idle keep-alive socket), it
    is retried once on a fresh connection. Use the session as a context
    manager, or call :meth:`clear`, to close every pooled connection.
    """

    def __init__(
        self,
        proxies=None,
        timeout=None,
        chunk_size=None,
        decode_content=None,
        verify=None,
    ):
        """Configure the session.

        :param proxies: mapping of target scheme (``"http"``/``"https"``) to a
            proxy URL such as ``"socks5://host:1080"``. The mapping passed in
            is left unmodified.
        :param timeout: default socket timeout in seconds (default ``60``).
        :param chunk_size: maximum bytes requested per ``recv`` call
            (default 1 MiB).
        :param decode_content: undo ``Content-Encoding`` on response bodies
            (default ``True``).
        :param verify: verify TLS certificates and hostnames (default ``True``).
        :raises UnsupportedScheme: for an unknown target or proxy scheme.
        """
        self.timeout = timeout if timeout is not None else 60
        self.max_chunk_size = chunk_size if chunk_size is not None else (1024**2)
        self.decode_content = decode_content if decode_content is not None else True
        self.verify = verify if verify is not None else True

        # Parse into a new dict instead of overwriting the caller's values.
        parsed_proxies = {}
        for scheme, proxy_url in (proxies or {}).items():
            proxy = urlsplit(proxy_url)
            if scheme not in scheme_to_port:
                raise UnsupportedScheme("'%s' is not a supported scheme" % (scheme))
            if proxy.scheme not in scheme_to_proxy_type:
                raise UnsupportedScheme(
                    "'%s' is not a supported proxy scheme" % (proxy.scheme)
                )
            parsed_proxies[scheme] = proxy

        self._scheme_to_proxy = parsed_proxies
        self._addr_to_conn = {}
        self._verified_context = self._new_ssl_context(verify=True)
        # Used for verify=False connections.
        self._unverified_context = self._new_ssl_context(verify=False)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.clear()

    def clear(self):
        """Shut down and close every pooled connection and empty the pool."""
        while self._addr_to_conn:
            _, sock = self._addr_to_conn.popitem()
            self._close_socket(sock)

    def request(
        self,
        method,
        url,
        headers=None,
        data=None,
        timeout=None,
        verify=None,
        ciphers=None,
        version=None,
    ):
        """Send an HTTP request and return the parsed ``Response``.

        :param method: request method, e.g. ``"GET"``.
        :param url: absolute ``http://`` or ``https://`` URL.
        :param headers: optional mapping of request headers. It is copied,
            never modified. ``Host`` and (when *data* is given)
            ``Content-Length`` are added if missing. Headers whose value is
            ``None`` are not sent.
        :param data: optional body; ``str`` is encoded as UTF-8.
        :param timeout: socket timeout in seconds; defaults to the session's.
        :param verify: verify TLS certificates; defaults to the session's.
            Like *ciphers*, it only takes effect when a new connection has to
            be opened, because pooled connections are keyed by address alone.
        :param ciphers: OpenSSL cipher string for a newly opened connection.
        :param version: HTTP version for the request line (default ``"1.1"``).
        :raises UnsupportedScheme: if the URL scheme is not http/https.
        :raises RequestException: on connection or protocol failures.
        """
        parsed_url = urlsplit(url)
        if parsed_url.scheme not in scheme_to_port:
            raise UnsupportedScheme(
                "'%s' is not a supported scheme" % (parsed_url.scheme)
            )
        hostname = parsed_url.hostname
        if not hostname:
            raise RequestException("No host in URL '%s'" % (url))

        if verify is None:
            verify = self.verify
        if version is None:
            version = "1.1"
        if timeout is None:
            timeout = self.timeout

        default_port = scheme_to_port[parsed_url.scheme]
        port = parsed_url.port or default_port

        # Always copy, so a CaseInsensitiveDict passed by the caller is not mutated.
        headers = CaseInsensitiveDict(headers)
        if "Host" not in headers:
            headers["Host"] = self._host_header(hostname, port, default_port)

        if data is not None:
            if isinstance(data, (bytearray, memoryview)):
                data = bytes(data)
            elif not isinstance(data, bytes):
                data = data.encode("utf-8")
            if "Content-Length" not in headers:
                headers["Content-Length"] = len(data)

        # The origin-form target must start with "/" even when only a query
        # is present ("http://h?a=1" -> "/?a=1").
        target = parsed_url.path or "/"
        if parsed_url.query:
            target += "?" + parsed_url.query

        request = self._prepare_request(
            method=method,
            path=target,
            version=version,
            headers=headers,
            body=data,
        )

        host_addr = (hostname.lower(), port)
        conn_reused = host_addr in self._addr_to_conn

        while True:
            try:
                conn = self._addr_to_conn.get(host_addr)
                if conn is None:
                    conn = self._create_socket(
                        host_addr,
                        proxy=self._scheme_to_proxy.get(parsed_url.scheme),
                        timeout=timeout,
                        ssl_wrap=(parsed_url.scheme == "https"),
                        ssl_verify=verify,
                        ciphers=ciphers,
                    )
                    self._addr_to_conn[host_addr] = conn
                elif timeout:
                    # Reset every time: an earlier request may have left a
                    # different per-request timeout on this pooled socket.
                    conn.settimeout(timeout)

                # send() may write only part of the buffer; sendall() loops.
                conn.sendall(request)
                return self._get_response(
                    conn, self.max_chunk_size, self.decode_content, method=method
                )

            except Exception as err:
                # The connection state is unknown now: drop it from the pool
                # and close it so the file descriptor is not leaked.
                stale = self._addr_to_conn.pop(host_addr, None)
                if stale is not None:
                    self._close_socket(stale)

                if not conn_reused:
                    if isinstance(err, RequestException):
                        raise
                    raise RequestException(err) from err

                # A pooled keep-alive socket may have been closed by the
                # server while idle; retry once on a new connection.
                conn_reused = False

    def get(self, url, **kwargs):
        return self.request("GET", url, **kwargs)

    def post(self, url, **kwargs):
        return self.request("POST", url, **kwargs)

    def options(self, url, **kwargs):
        return self.request("OPTIONS", url, **kwargs)

    def head(self, url, **kwargs):
        return self.request("HEAD", url, **kwargs)

    def put(self, url, **kwargs):
        return self.request("PUT", url, **kwargs)

    def patch(self, url, **kwargs):
        return self.request("PATCH", url, **kwargs)

    def delete(self, url, **kwargs):
        return self.request("DELETE", url, **kwargs)

    @staticmethod
    def _host_header(hostname, port, default_port):
        """Build a ``Host`` value; the port is included only when non-default."""
        # IPv6 literals must be bracketed in the Host header.
        host = "[%s]" % hostname if ":" in hostname else hostname
        return host if port == default_port else "%s:%d" % (host, port)

    def _create_socket(
        self,
        dest_addr,
        proxy=None,
        timeout=None,
        ssl_wrap=True,
        ssl_verify=True,
        remote_dns=False,
        ciphers=None,
    ):
        """Open a connected socket to *dest_addr*, optionally via *proxy* and TLS.

        :param dest_addr: ``(hostname, port)`` of the origin server.
        :param proxy: ``urlsplit`` result of the proxy URL, or ``None``.
        :param timeout: socket timeout in seconds; falsy leaves the default.
        :param ssl_wrap: wrap the connection in TLS.
        :param ssl_verify: verify the server certificate and hostname.
        :param remote_dns: let the proxy resolve the hostname. It is also
            enabled for ``socks4a``/``socks5h`` proxy URLs.
        :param ciphers: OpenSSL cipher string for this connection only.
        """
        if proxy is None:
            # create_connection resolves the name and tries every returned
            # address (IPv4 or IPv6) instead of assuming AF_INET.
            if timeout:
                sock = socket.create_connection(dest_addr, timeout)
            else:
                sock = socket.create_connection(dest_addr)
        else:
            sock = socks.socksocket()
            try:
                sock.set_proxy(
                    scheme_to_proxy_type[proxy.scheme],
                    addr=proxy.hostname,
                    port=proxy.port,
                    username=proxy.username,
                    password=proxy.password,
                    rdns=remote_dns or proxy.scheme in ("socks4a", "socks5h"),
                )
                if timeout:
                    sock.settimeout(timeout)
                sock.connect(dest_addr)
            except BaseException:
                sock.close()
                raise

        if ssl_wrap:
            try:
                context = self._ssl_context(ssl_verify, ciphers)
                sock = context.wrap_socket(sock, server_hostname=dest_addr[0])
            except BaseException:
                sock.close()
                raise

        return sock

    def _ssl_context(self, verify, ciphers):
        """Return the shared TLS context, or a private one when *ciphers* is set."""
        if ciphers is None:
            return self._verified_context if verify else self._unverified_context
        # set_ciphers() mutates its context. Applying it to a shared context
        # would silently change the ciphers of every later connection.
        context = self._new_ssl_context(verify)
        context.set_ciphers(ciphers)
        return context

    @staticmethod
    def _new_ssl_context(verify):
        """Create a default client TLS context, optionally without verification."""
        context = ssl.create_default_context()
        if not verify:
            # check_hostname must be disabled before verify_mode can be CERT_NONE.
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        return context

    @staticmethod
    def _close_socket(sock):
        """Shut down and close *sock*, ignoring errors from an already-dead peer."""
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass
        try:
            sock.close()
        except OSError:
            pass

    @staticmethod
    def _prepare_request(method, path, version, headers, body):
        """Serialise the request line, headers (skipping ``None``) and body to bytes."""
        lines = ["%s %s HTTP/%s" % (method, path, version)]
        lines.extend(
            "%s: %s" % (name, value)
            for name, value in headers.items()
            if value is not None
        )
        head = ("\r\n".join(lines) + "\r\n\r\n").encode("UTF-8")
        return head if body is None else head + body

    @staticmethod
    def _get_response(conn, max_chunk_size, decode_content, method=None):
        """Read one complete HTTP response from *conn*.

        The body length is determined HTTP/1.1-style. ``HEAD`` responses and
        1xx/204/304 statuses have no body. Otherwise chunked transfer coding
        wins, then ``Content-Length``, and finally "read until the peer closes".

        :param method: the request method. It is used to recognise ``HEAD``
            responses; it is optional for backward compatibility.
        :raises EmptyResponse: if the server sent nothing at all.
        :raises RequestException: on a malformed or truncated response.
        """
        head, data = Session._read_head(conn, max_chunk_size)
        status_line, _, raw_headers = Session._decode_head(head).partition("\r\n")
        # The reason phrase is optional ("HTTP/1.1 200" is a valid status line).
        parts = status_line.split(" ", 2)
        if len(parts) < 2:
            raise RequestException("Malformed status line '%s'" % (status_line))
        status = int(parts[1])
        message = parts[2] if len(parts) > 2 else ""
        headers = Session._parse_headers(raw_headers)

        transfer_encoding = Session._header_text(headers, "transfer-encoding")
        content_length = headers.get("content-length")
        if isinstance(content_length, list):
            content_length = content_length[0]

        if not Session._has_body(method, status):
            data = b""
        elif (
            transfer_encoding
            and transfer_encoding.split(",")[-1].strip().lower() == "chunked"
        ):
            data = Session._read_chunked(conn, data, max_chunk_size)
        elif content_length is not None:
            data = Session._read_exact(conn, data, content_length, max_chunk_size)
        else:
            data = Session._read_until_close(conn, data, max_chunk_size)

        content_encoding = Session._header_text(headers, "content-encoding")
        # An empty body (e.g. HEAD) has nothing to decompress.
        if content_encoding and decode_content and data:
            data = Session._decode_content(data, content_encoding)

        return Response(status, message, headers, data)

    @staticmethod
    def _read_head(conn, max_chunk_size):
        """Receive up to the blank line that ends the header section.

        :returns: ``(header_bytes, body_bytes_already_received)``.
        """
        buf = bytearray(conn.recv(max_chunk_size))
        if not buf:
            raise EmptyResponse("Empty response from server")
        end = buf.find(b"\r\n\r\n")
        while end == -1:
            # The terminator may straddle two recv()s; rescan the last 3 bytes.
            searched = max(len(buf) - 3, 0)
            chunk = conn.recv(max_chunk_size)
            if not chunk:
                raise RequestException("Connection closed before end of headers")
            buf += chunk
            end = buf.find(b"\r\n\r\n", searched)
        return bytes(buf[:end]), bytes(buf[end + 4 :])

    @staticmethod
    def _decode_head(head):
        """Decode header bytes as UTF-8, falling back to ISO-8859-1."""
        try:
            return head.decode()
        except UnicodeDecodeError:
            # ISO-8859-1 maps every byte, so this fallback cannot fail.
            return head.decode("iso-8859-1")

    @staticmethod
    def _parse_headers(raw_headers):
        """Parse header lines; repeated names collect their values in a list."""
        headers = CaseInsensitiveDict()
        for line in raw_headers.split("\n"):
            name, sep, value = line.rstrip("\r").partition(":")
            if not sep:
                continue  # blank or malformed line: skip instead of failing
            value = value.strip(" \t")
            if name in headers:
                previous = headers[name]
                if isinstance(previous, str):
                    headers[name] = [previous, value]
                else:
                    previous.append(value)
            else:
                headers[name] = value
        return headers

    @staticmethod
    def _header_text(headers, name):
        """Return header *name*, joining repeated values with ``", "``."""
        value = headers.get(name)
        return ", ".join(value) if isinstance(value, list) else value

    @staticmethod
    def _has_body(method, status):
        """Whether a response to *method* with *status* may carry a body."""
        if method is not None and method.upper() == "HEAD":
            return False
        return not (100 <= status < 200 or status in (204, 304))

    @staticmethod
    def _read_exact(conn, data, content_length, max_chunk_size):
        """Return exactly ``Content-Length`` bytes, starting from *data*."""
        try:
            goal = int(content_length)
        except ValueError:
            goal = -1
        if goal < 0:
            raise RequestException("Invalid Content-Length '%s'" % (content_length))
        body = bytearray(data)
        while len(body) < goal:
            chunk = conn.recv(min(goal - len(body), max_chunk_size))
            if not chunk:
                raise RequestException("Empty chunk")
            body += chunk
        # Bytes beyond Content-Length are not part of this response.
        return bytes(body[:goal])

    @staticmethod
    def _read_until_close(conn, data, max_chunk_size):
        """Return *data* plus everything received until the peer closes."""
        body = bytearray(data)
        chunk = conn.recv(max_chunk_size)
        while chunk:
            body += chunk
            chunk = conn.recv(max_chunk_size)
        return bytes(body)

    @staticmethod
    def _read_chunked(conn, data, max_chunk_size):
        """Decode a ``Transfer-Encoding: chunked`` body.

        *data* holds bytes already received after the header section. Reading
        stops after the zero-size last chunk and its trailer section, so a
        kept-alive connection is not waited on for more data.
        """
        buf = bytearray(data)
        pos = 0

        def receive_more():
            chunk = conn.recv(max_chunk_size)
            if not chunk:
                raise RequestException("Connection closed inside chunked body")
            buf.extend(chunk)

        def read_line():
            nonlocal pos
            end = buf.find(b"\r\n", pos)
            while end == -1:
                receive_more()
                end = buf.find(b"\r\n", pos)
            line = bytes(buf[pos:end])
            pos = end + 2
            return line

        def read_bytes(count):
            nonlocal pos
            while len(buf) - pos < count:
                receive_more()
            out = bytes(buf[pos : pos + count])
            pos += count
            return out

        body = bytearray()
        while True:
            size_line = read_line()
            # Chunk extensions (";name=value") may follow the hex size.
            size_field = size_line.split(b";", 1)[0].strip()
            try:
                size = int(size_field, 16)
            except ValueError:
                size = -1
            if size < 0:
                raise RequestException("Invalid chunk size %r" % (size_line,))
            if size == 0:
                break
            body += read_bytes(size)
            if read_bytes(2) != b"\r\n":
                raise RequestException("Missing CRLF after chunk data")

        # Skip the (usually empty) trailer section up to its blank line.
        while read_line():
            pass
        return bytes(body)

    @staticmethod
    def _decode_content(content, encoding):
        """Undo the ``Content-Encoding`` *encoding* applied to *content*.

        *encoding* may list several codings (``"gzip, br"``). They were
        applied in that order, so they are removed in reverse. Names are
        case-insensitive, and ``identity`` is a no-op.

        :raises UnsupportedEncoding: for an unknown coding.
        """
        codings = [coding.strip().lower() for coding in encoding.split(",")]
        for coding in reversed(codings):
            if coding in ("", "identity"):
                continue
            if coding == "br":
                content = brotli.decompress(content)
            elif coding in ("gzip", "x-gzip"):
                content = gzip.decompress(content)
            elif coding == "deflate":
                content = Session._inflate(content)
            else:
                raise UnsupportedEncoding(
                    "Unknown encoding type '%s' while decoding content" % (encoding)
                )
        return content

    @staticmethod
    def _inflate(content):
        """Decompress ``deflate`` content, zlib-wrapped or raw."""
        try:
            return zlib.decompress(content)
        except zlib.error:
            # Some servers send raw DEFLATE data without the zlib header.
            return zlib.decompress(content, -zlib.MAX_WBITS)
