from threading import Thread, Lock, Event
from queue import Queue, Empty
from select import select
from socket import socket, AF_INET, AF_INET6
from enum import IntEnum
from typing import List, Union, Optional, Callable, Type
from ecdsa.keys import SigningKey, VerifyingKey
from ecdsa.curves import SECP256k1
from Cryptodome.Cipher import AES
from Cryptodome.Cipher._mode_gcm import GcmMode
from Cryptodome.Random import get_random_bytes
from binascii import a2b_hex
from hashlib import sha256
from time import time, sleep
from ipaddress import IPv4Address, IPv6Address
import logging
import json
import zlib


def uniq_id_generator() -> Callable[[], int]:
    """Create an independent, thread-safe counter for Sock ids.

    The returned function yields 1, 2, 3, ... on successive calls; every call
    of this factory starts a brand-new sequence.
    """
    guard = Lock()
    state = [0]  # boxed so the closure can update it without `nonlocal`

    def _get_uuid() -> int:
        with guard:
            state[0] = state[0] + 1
            issued = state[0]
        return issued

    return _get_uuid


# module-wide logger
log = logging.getLogger(__name__)
# user callback signature: (received msg, Sock it arrived on) -> None
CallbackFnc = Callable[[bytes, 'Sock'], None]
# process-wide source of unique Sock ids
get_uuid = uniq_id_generator()


class SockControl(IntEnum):
    """Control bits, used both in a Sock's `flags` and in each msg's flag byte."""
    ENCRYPTED = 1 << 0  # msg is AES-GCM encrypted
    COMPRESSED = 1 << 1  # msg is zlib compressed
    VALIDATED = 1 << 2  # the other's public key was validated
    INNER_WORK = 1 << 3  # msg belongs to an internal handshake job
    WARNING = 1 << 4  # msg is a warning text from the other
    CLOSING = 1 << 7  # msg is a closing reason from the other


class SockType(IntEnum):
    """Role of a Sock inside a SockPool."""
    # listening socket; accepting on it creates INBOUND socks
    SERVER = 0
    # connection accepted by one of our SERVER socks
    INBOUND = 1
    # connection we initiated ourselves
    OUTBOUND = 2


# elliptic curve used for every key in this module: the per-job ECDH keys of
# Sock.establish_encryption(), the public keys checked by
# Sock.validate_the_other(), and SockPool's secret key
CURVE = SECP256k1


class UnrecoverableClose(Exception):
    """Raised when a received msg cannot be processed; the socket gets closed."""


class GracefulClose(Exception):
    """Raised when the other side announces that it is closing the connection."""


class JobCorruptedWarning(Warning):
    """Raised when a job fails; the socket stays open and keeps listening."""


class Sock(object):
    """client socket object wrapper and callback

    every msg on the wire is one chunk: [flag 1byte][length 4bytes little-endian][msg]
    where `flag` is a combination of SockControl bits (see sendall() and recv())
    """

    def __init__(
            self,
            sock: socket,
            callback: CallbackFnc,
            stype: SockType,
            others_key: Optional[VerifyingKey],
            secret_key: SigningKey,
    ):
        # note: a timeout of None means *blocking* mode; the pool only calls recv()
        # after select() reported the socket as readable
        assert sock.gettimeout() is None, "only blocking mode (timeout=None) is supported"
        assert sock.family in (AF_INET, AF_INET6)
        self.id = get_uuid()
        self.sock = sock
        self.lock = Lock()  # serializes sendall(), close() and inner work result commits
        self.callback = callback  # called with (msg, self) for every outer work msg
        # note: check inner work is processing by `inner_event is not None`
        # note: wait other inner work by `inner_event.wait(sec)`
        self.inner_que: Optional[Queue[bytes]] = None
        self.inner_event: Optional[Event] = None
        self.stype = stype
        self.tmp = bytearray()  # recv buffer
        self.common_key: Optional[bytes] = None  # 32bytes AES key
        self.others_key = others_key  # others public key
        self.secret_key = secret_key  # my secret key
        self.flags = 0b00000000  # default no flag
        self.delay = 0.0  # set by measure_delay_time(): other's time() minus my time() (sec)

    def __repr__(self) -> str:
        if self.sock.fileno() == -1:
            host = "closed"
        else:
            try:
                host = str(self.get_opposite_host())
            except OSError:
                # e.g. getpeername() on a socket that is not connected
                host = "unknown"
        flag = "flag=" + \
            ("E" if self.flags & SockControl.ENCRYPTED else "_") + \
            ("C" if self.flags & SockControl.COMPRESSED else "_") + \
            ("V" if self.flags & SockControl.VALIDATED else "_")
        family = getattr(self.sock.family, "name", str(self.sock.family))
        class_origin = str(self.sock.__class__.__name__)
        return f"<Sock {host} id={self.id} {flag} {self.stype.name} {family} {class_origin}>"

    @staticmethod
    def _encrypt(key: bytes, data: bytes) -> bytes:
        """encrypt by AES-GCM (more secure than CBC mode)"""
        cipher: 'GcmMode' = AES.new(key, AES.MODE_GCM)  # type: ignore
        # warning: Don't reuse nonce
        enc, tag = cipher.encrypt_and_digest(data)
        # output length = 16bytes + 16bytes + N(=data)bytes
        return cipher.nonce + tag + enc

    @staticmethod
    def _decrypt(key: bytes, data: bytes) -> bytes:
        """decrypt `nonce(16b) + tag(16b) + ciphertext` produced by _encrypt()

        raises UnrecoverableClose when the data is truncated or verification fails
        """
        if len(data) < 32:
            # too short to even hold nonce + tag
            raise UnrecoverableClose("decryption failed")
        cipher: 'GcmMode' = AES.new(key, AES.MODE_GCM, nonce=data[:16])  # type: ignore
        # ValueError raised when verify failed
        try:
            return cipher.decrypt_and_verify(data[32:], data[16:32])
        except ValueError:
            pass
        raise UnrecoverableClose("decryption failed")

    @staticmethod
    def _derive_shared_key(my_sk: SigningKey, others_pk_hex: str) -> bytes:
        """ECDH: sha256 of the x coordinate of (my secret * other's public point)"""
        other_pk = VerifyingKey.from_string(a2b_hex(others_pk_hex), curve=CURVE)
        shared_point = my_sk.privkey.secret_multiplier * other_pk.pubkey.point
        return sha256(shared_point.x().to_bytes(32, 'big')).digest()

    def recv(self) -> None:
        """push new raw bytes and dispatch every complete chunk msg

        raises UnrecoverableClose when the other closed the connection or a msg
        cannot be decrypted, GracefulClose when the other sent a closing request;
        exceptions from the callback / inner work ignition propagate to the caller
        """
        assert self.stype is not SockType.SERVER

        data = self.sock.recv(4096)
        if len(data) == 0:
            raise UnrecoverableClose("receive zero message (socket is closed)")
        self.tmp.extend(data)

        # one recv() may deliver several msgs: handle every complete one now, because
        # select() will not report the bytes that are already sitting in self.tmp
        while True:
            chunk = self._pop_chunk()
            if chunk is None:
                return  # not found new msg
            flag, msg = chunk
            self._dispatch(flag, msg)

    def _pop_chunk(self) -> Optional[tuple]:
        """remove one complete `(flag, msg)` from the recv buffer, or return None"""
        if len(self.tmp) < 5:
            return None  # header itself is incomplete
        flag = self.tmp[0]  # 1byte
        length = int.from_bytes(self.tmp[1:5], "little")  # 4bytes
        end = 5 + length
        if len(self.tmp) < end:
            return None
        msg = bytes(self.tmp[5:end])
        del self.tmp[:end]
        return flag, msg

    def _wait_common_key(self) -> bytes:
        """return the AES key, waiting max 1sec for a running inner work to provide it"""
        common_key = self.common_key
        if common_key is not None:
            return common_key
        inner_event = self.inner_event  # read once: the inner thread resets it to None
        if inner_event is None:
            raise UnrecoverableClose("try to decrypt msg but not found key (1)")
        if not inner_event.wait(1.0):
            raise UnrecoverableClose("try to decrypt msg but not found key (3)")
        common_key = self.common_key
        if common_key is None:
            raise UnrecoverableClose("try to decrypt msg but not found key (2)")
        return common_key

    def _dispatch(self, flag: int, msg: bytes) -> None:
        """decode one chunk msg according to its flag and route it"""
        # decrypt (note: may block the listen thread max 1sec)
        if flag & SockControl.ENCRYPTED:
            msg = self._decrypt(self._wait_common_key(), msg)

        # decompress
        if flag & SockControl.COMPRESSED:
            msg = zlib.decompress(msg)

        # execute callback with same thread
        # warning: don't block because a listen thread fall in dysfunctional
        if flag & SockControl.INNER_WORK:
            inner_que = self.inner_que  # read once: the inner thread resets it to None
            if inner_que is None:
                # start inner work thread
                self._ignite_inner_work(msg)
            else:
                inner_que.put(msg)
        elif flag & SockControl.WARNING:
            # just warning log by the other
            log.warning("warning %s: %s", self, msg.decode(errors='ignore'))
        elif flag & SockControl.CLOSING:
            # closing request
            raise GracefulClose(f"closing request: {msg.decode(errors='ignore')}")
        else:
            # is outer work
            self.callback(msg, self)

    def sendall(self, msg: bytes, flag: Optional[int] = None) -> None:
        """send one chunked msg `[flag 1b][len 4b][msg xb]`

        `flag` defaults to this sock's current `flags`; raises JobCorruptedWarning
        when ENCRYPTED is requested but no common key exists yet
        """
        assert self.stype is not SockType.SERVER

        # manual setup flag
        if flag is None:
            flag = self.flags

        # convert: compress first, then encrypt (recv() undoes it in reverse order)
        if flag & SockControl.COMPRESSED:
            msg = zlib.compress(msg)
        if flag & SockControl.ENCRYPTED:
            common_key = self.common_key
            if common_key is None:
                raise JobCorruptedWarning("try to encrypt msg but not found key")
            msg = self._encrypt(common_key, msg)

        with self.lock:
            # convert: [flag 1b][len 4b][msg xb]
            self.sock.sendall(flag.to_bytes(1, "little") + len(msg).to_bytes(4, "little") + msg)

    def _ignite_inner_work(self, msg: bytes) -> None:
        """start the standby side of the inner work requested by the other's first msg

        `msg` must be a JSON object whose "ignite" value selects the job; the msg
        itself is then queued for the freshly started job.
        raises JobCorruptedWarning when there is no "ignite" key,
        NotImplementedError when the requested job is unknown
        """
        assert self.inner_que is None
        obj = json.loads(msg.decode())

        if not isinstance(obj, dict) or "ignite" not in obj:
            raise JobCorruptedWarning("no ignite flag and no inner work")

        starters = {
            "encryption": self.establish_encryption,
            "validation": self.validate_the_other,
            "measure-delay": self.measure_delay_time,
        }
        ignite = obj["ignite"]
        starter = starters.get(ignite) if isinstance(ignite, str) else None
        if starter is None:
            raise NotImplementedError(f"unknown ignite flag: {obj}")

        starter(False)
        inner_que = self.inner_que
        assert inner_que is not None
        inner_que.put(msg)

    def _wait_inner_work(self, is_primary: bool) -> None:
        """make sure no other inner work is running before starting a new one

        primary (we start the job): wait max 5sec for a running inner work to finish
        standby (the other started it): there must be no inner work at all
        """
        if is_primary:
            inner_event = self.inner_event  # read once: the inner thread resets it to None
            if inner_event is not None and inner_event.wait(5.0) is False:
                raise AssertionError("inner work is processing now")
        else:
            assert self.inner_event is None

    def _send_inner_json(self, obj: dict) -> None:
        """send inner work msg by plain"""
        self.sendall(json.dumps(obj).encode() + b"\n", SockControl.INNER_WORK)

    @staticmethod
    def _recv_inner_json(recv_que: 'Queue[bytes]') -> dict:
        """recv inner work msg or raise queue.Empty after 5sec"""
        return json.loads(recv_que.get(True, 5.0).decode())

    def _start_inner_work(
            self,
            label: str,
            thread_name: str,
            work: Callable[['Queue[bytes]'], Callable[[], None]],
    ) -> Event:
        """run an inner work job on a background thread

        `work(recv_que)` talks to the other (recv() routes INNER_WORK msgs into
        `recv_que`) and returns a commit function; the commit is applied under
        `self.lock` only when `work` finished without exception. In every case the
        inner queue/event are reset and the returned event is set at the end.
        """
        now = time()
        recv_que: Queue[bytes] = Queue()
        finished_event = Event()

        def inner_work() -> None:
            log.debug("start %s()", label)
            try:
                commit = work(recv_que)
            except Empty:
                log.debug("queue is empty on %s()", label)
            except AssertionError as e:
                log.info("AssertionError %s", e)
                log.debug(self, exc_info=True)
            except Exception:
                log.warning("Exception", exc_info=True)
            else:
                # success
                with self.lock:
                    commit()
            finally:
                # close inner process
                self.inner_que = None
                self.inner_event = None
                finished_event.set()
            log.debug("finish %s() %fs", label, time() - now)

        # start background thread
        self.inner_que = recv_que
        self.inner_event = finished_event
        Thread(target=inner_work, name=thread_name).start()
        return finished_event

    def establish_encryption(self, is_primary: bool) -> Event:
        """establish encrypted connection (ECDH on CURVE, then AES-GCM)

        primary: starts the job by sending a fresh public key
        standby: answers with its public key and a random 32bytes common key
        encrypted by the ECDH shared key; the primary confirms with an encrypted msg.
        returns the event set when the background job finished (success or not);
        on success `common_key` is stored and ENCRYPTED is added to `flags`
        """
        assert self.common_key is None, "already established encryption (key)"
        assert not (self.flags & SockControl.ENCRYPTED), "already established encryption (flag)"
        self._wait_inner_work(is_primary)
        established_msg = b"success establish encrypted connection"

        def work(recv_que: 'Queue[bytes]') -> Callable[[], None]:
            my_sk = SigningKey.generate(CURVE)
            my_pk = my_sk.get_verifying_key()
            if is_primary:
                # first stage (primary): send pk & curve
                self._send_inner_json({
                    "ignite": "encryption",
                    "curve": str(CURVE.name),
                    "public-key": my_pk.to_string().hex()
                })
                # second stage (primary): receive pk & encryptedKey, calc sharedKey,
                # decrypt commonKey, send hello msg encrypted by sharedKey
                obj = self._recv_inner_json(recv_que)
                assert "public-key" in obj
                shared_key = self._derive_shared_key(my_sk, obj["public-key"])
                assert "encrypted-key" in obj
                common_key = self._decrypt(shared_key, a2b_hex(obj["encrypted-key"]))
                encrypted_msg = self._encrypt(shared_key, established_msg)
                self._send_inner_json({
                    "encrypted-msg": encrypted_msg.hex()
                })
            else:
                # first stage (standby): receive pk, calc sharedKey, send my pk
                # together with a new commonKey encrypted by sharedKey
                obj = self._recv_inner_json(recv_que)
                assert str(obj.get("curve")) == CURVE.name, "{} is not same curve".format(obj.get("curve"))
                assert "public-key" in obj
                shared_key = self._derive_shared_key(my_sk, obj["public-key"])
                common_key = get_random_bytes(32)
                encrypted_key = self._encrypt(shared_key, common_key)
                self._send_inner_json({
                    "public-key": my_pk.to_string().hex(),
                    "encrypted-key": encrypted_key.hex()
                })
                # second stage (standby): confirm connection by decrypting the hello msg
                obj = self._recv_inner_json(recv_que)
                assert "encrypted-msg" in obj
                msg = self._decrypt(shared_key, a2b_hex(obj["encrypted-msg"]))
                assert msg == established_msg

            def commit() -> None:
                self.common_key = common_key
                self.flags |= SockControl.ENCRYPTED
            return commit

        return self._start_inner_work("establish_encryption", "InnerEnc", work)

    def validate_the_other(self, is_primary: bool) -> Event:
        """validate the other by publicKey

        both sides exchange 32 random bytes, sign `my_random + other_random` with their
        secret key and verify the other's signature; when `others_key` is already
        known, the received public key must match it.
        returns the event set when the background job finished (success or not);
        on success VALIDATED is added to `flags` and `others_key` is stored
        """
        self._wait_inner_work(is_primary)

        def work(recv_que: 'Queue[bytes]') -> Callable[[], None]:
            random_key = get_random_bytes(32)
            others_key: Optional[VerifyingKey] = self.others_key

            def verify_other(obj: dict, another_key: bytes) -> VerifyingKey:
                """check the other's public key & signature over `another_key + random_key`"""
                assert "public-key" in obj
                others_public_key = a2b_hex(obj["public-key"])
                assert "signature" in obj
                others_signature = a2b_hex(obj["signature"])
                if others_key is None:
                    verified_key = VerifyingKey.from_string(others_public_key, curve=CURVE)
                else:
                    assert others_public_key == others_key.to_string(), "don't match public key"
                    verified_key = others_key
                # an invalid signature makes verify() raise, which fails the job
                verified_key.verify(others_signature, another_key + random_key, hashfunc=sha256)
                return verified_key

            if is_primary:
                # first stage (primary): send my random bytes
                self._send_inner_json({
                    "ignite": "validation",
                    "curve": str(CURVE.name),
                    "another-key": random_key.hex(),
                })
                # second stage (primary): verify the other, then sign & send
                obj = self._recv_inner_json(recv_que)
                assert "another-key" in obj
                another_key = a2b_hex(obj["another-key"])
                assert len(another_key) == 32
                verified_key = verify_other(obj, another_key)
                signature = self.secret_key.sign(random_key + another_key, hashfunc=sha256)
                self._send_inner_json({
                    "public-key": self.secret_key.get_verifying_key().to_string().hex(),
                    "signature": signature.hex(),
                })
            else:
                # first stage (standby): receive random bytes, sign & send mine
                obj = self._recv_inner_json(recv_que)
                assert str(obj.get("curve")) == CURVE.name, "{} is not same curve".format(obj.get("curve"))
                assert "another-key" in obj
                another_key = a2b_hex(obj["another-key"])
                assert len(another_key) == 32
                signature = self.secret_key.sign(random_key + another_key, hashfunc=sha256)
                self._send_inner_json({
                    "another-key": random_key.hex(),
                    "public-key": self.secret_key.get_verifying_key().to_string().hex(),
                    "signature": signature.hex()
                })
                # second stage (standby): verify the other
                obj = self._recv_inner_json(recv_que)
                verified_key = verify_other(obj, another_key)

            def commit() -> None:
                self.flags |= SockControl.VALIDATED
                self.others_key = verified_key
            return commit

        return self._start_inner_work("validate_the_other", "InnerVal", work)

    def measure_delay_time(self, is_primary: bool) -> Event:
        """check delay time by ping

        both sides exchange their `time()` stamps; on success `delay` becomes
        `other's stamp - my stamp` in seconds (the stamps come from two different
        clocks, so the value mixes transit time and any clock offset).
        returns the event set when the background job finished (success or not)
        """
        self._wait_inner_work(is_primary)

        def work(recv_que: 'Queue[bytes]') -> Callable[[], None]:
            if is_primary:
                my_now = time()
                self._send_inner_json({
                    "ignite": "measure-delay",
                    "now": my_now,
                })
                obj = self._recv_inner_json(recv_que)
                assert "now" in obj
                others_now = obj["now"]
            else:
                obj = self._recv_inner_json(recv_que)
                assert "now" in obj
                others_now = obj["now"]
                my_now = time()
                self._send_inner_json({
                    "now": my_now,
                })
            # reject a non-numeric stamp here: the commit runs outside error handling
            assert isinstance(others_now, (int, float)), "invalid timestamp"

            def commit() -> None:
                self.delay = others_now - my_now
            return commit

        return self._start_inner_work("measure_delay_time", "InnerDly", work)

    def get_opposite_host(self) -> Union[IPv4Address, IPv6Address]:
        """get the other end of the connection host name

        for a SERVER sock this is the local bound address instead
        """
        family: Union[Type[IPv4Address], Type[IPv6Address]]

        if self.sock.family is AF_INET:
            family = IPv4Address
        elif self.sock.family is AF_INET6:
            family = IPv6Address
        else:
            raise NotImplementedError(f"not found family: {self.sock.family}")

        # note: raise OSError if socket is closed
        if self.stype is SockType.SERVER:
            return family(self.sock.getsockname()[0])
        elif self.stype is SockType.INBOUND or self.stype is SockType.OUTBOUND:
            return family(self.sock.getpeername()[0])
        else:
            raise NotImplementedError(f"not found stype: {self.stype}")

    def fileno(self) -> int:
        """is for select poll"""
        return self.sock.fileno()

    def close(self) -> None:
        """note: use pool's close_sock() instead

        waits for a running inner work to finish, then closes the raw socket
        """
        inner_event = self.inner_event  # read once: the inner thread resets it to None
        if inner_event is not None:
            inner_event.wait()
        with self.lock:
            self.sock.close()


class SockPool(Thread):
    """listen thread that owns a list of Sock objects

    run() select()s over every registered sock: SERVER socks are accept()ed into
    new INBOUND socks, other socks are read by Sock.recv(); socks whose processing
    fails fatally are closed and removed. close() stops the thread.
    """

    def __init__(self, name: str = 'SockPool', secret: Optional[bytes] = None) -> None:
        super().__init__(name=name)
        self.socks: List[Sock] = list()
        self.lock = Lock()  # guards self.socks
        # secret key shared by every Sock of this pool (checked in add_sock())
        if secret is None:
            self.secret_key = SigningKey.generate(CURVE)
        else:
            self.secret_key = SigningKey.from_string(secret, CURVE)
        # status
        # note: `running` is SET while the listen loop is NOT running (before run()
        # and after run() returned) and CLEARED while run() is listening
        self.running = Event()
        self.closing = False  # is closing now
        self.closed = False  # class is closed
        # init
        self.running.set()

    def __repr__(self) -> str:
        pubkey = self.secret_key.get_verifying_key().to_string().hex()
        if self.closed:
            status = "closed"
        elif self.closing:
            status = "closing"
        elif self.running.is_set():
            status = "waiting"  # listen loop is not running
        else:
            status = "running"
        return f"<SockPool pub={pubkey[:10]}..{pubkey[-10:]} status={status} len={len(self.socks)}>"

    def run(self) -> None:
        """listen loop: poll socks until close() is requested, then close them all"""
        assert self.running.is_set(), "already running main thread"

        self.running.clear()
        try:
            # listen sockets
            while not self.closing:
                self._poll_once()
        finally:
            # closing (also reached when the loop died unexpectedly, so that
            # close() waiting on `running` can never hang forever)
            try:
                with self.lock:
                    socks = self.socks.copy()
                log.debug("listening thread closing len=%d", len(socks))
                for sock in socks:
                    self.close_sock(sock, b"close sock pool now")
                log.info("listening thread close success")
            finally:
                self.running.set()

    def _poll_once(self) -> None:
        """wait max 0.2sec for readable socks and handle each of them"""
        # note: only socket object is supported
        # note: snapshot under the lock but sleep/select outside it, so that
        # add_sock()/get_sock() are not blocked while we wait
        with self.lock:
            socks = [sock for sock in self.socks if sock.fileno() != -1]
        if len(socks) == 0:
            sleep(0.2)
            return
        try:
            r, _w, _x = select(socks, [], [], 0.2)
        except (OSError, ValueError):
            # a sock was closed by another thread after the snapshot was taken
            log.debug("select failed, retry", exc_info=True)
            return

        # socket recv() or accept()
        for sock in r:
            if sock.stype is SockType.SERVER:
                self._accept_inbound(sock)
            else:
                self._recv_from(sock)

    def _accept_inbound(self, server: Sock) -> None:
        """accept a new connection on a SERVER sock and register it as INBOUND"""
        try:
            raw_sock, _addr = server.sock.accept()
        except OSError:
            # an accept failure must not kill the listen thread
            log.info("accept failed %s", server, exc_info=True)
            return
        raw_sock.settimeout(None)
        new_sock = Sock(raw_sock, server.callback, SockType.INBOUND, None, self.secret_key)
        with self.lock:
            self.socks.append(new_sock)
        log.debug("accept %s", new_sock)

    def _recv_from(self, sock: Sock) -> None:
        """feed a readable sock to Sock.recv() and close it on fatal errors"""
        try:
            sock.recv()
        except NotImplementedError as e:
            log.info("NotImplementedError %s %s", sock, e)
        except JobCorruptedWarning as e:
            log.debug("JobCorruptedWarning %s %s", sock, e)
        except GracefulClose as e:
            log.debug("GracefulClose %s", e)
            self.close_sock(sock)
        except UnrecoverableClose as e:
            log.info("UnrecoverableClose %s", sock)
            log.debug(str(e), exc_info=True)
            self.close_sock(sock)
        except ConnectionError as e:
            log.info("ConnectionError %s", sock)
            log.debug(str(e), exc_info=True)
            self.close_sock(sock)
        except BlockingIOError:
            log.debug("BlockingIOError %s", sock, exc_info=True)
            self.close_sock(sock)
        except Exception:
            log.warning("Exception %s", sock, exc_info=True)
            self.close_sock(sock)

    def add_sock(self, sock: Sock) -> None:
        """add new sock, client or server (INBOUND socks are created by the pool itself)"""
        assert sock.secret_key is self.secret_key

        if sock.stype is SockType.SERVER:
            assert sock.others_key is None
        elif sock.stype is SockType.INBOUND:
            raise AssertionError("not allow inbound sock adding")
        elif sock.stype is SockType.OUTBOUND:
            pass
        else:
            raise NotImplementedError("unexpected sockType {}".format(sock))

        # note: check & append atomically so the same sock can't be added twice
        with self.lock:
            assert sock not in self.socks
            self.socks.append(sock)
        log.info("add new %s", sock)

    def get_sock(self, sock_id: int) -> Optional[Sock]:
        """get Sock from id"""
        with self.lock:
            for sock in self.socks:
                if sock_id == sock.id:
                    return sock
        return None

    def get_sock_from_socket(self, s: socket) -> Optional[Sock]:
        """get the Sock wrapping the raw socket object `s`"""
        with self.lock:
            for sock in self.socks:
                if sock.sock is s:
                    return sock
        return None

    def close_sock(self, sock: Sock, reason: Optional[bytes] = None) -> None:
        """graceful socket close

        sends `reason` as a CLOSING msg (when given and the sock is not a SERVER),
        closes the socket and removes it from the pool; errors are only logged
        """
        # send reason if socket is working
        if isinstance(reason, bytes) and sock.stype is not SockType.SERVER:
            try:
                sock.sendall(reason, SockControl.CLOSING)
            except ConnectionError as e:
                log.debug("ConnectionError %s", e)
            except Exception:
                log.debug("Exception", exc_info=True)
        # close socket is living yet (even if sending the reason failed)
        try:
            if sock.fileno() != -1:
                sock.close()
        except Exception:
            log.debug("Exception", exc_info=True)

        # remove from list
        with self.lock:
            if sock in self.socks:
                self.socks.remove(sock)
        log.info("close %s", sock)

    def close(self) -> None:
        """close and wait complete status"""
        self.closing = True
        self.running.wait()
        self.closed = True


# public API of this module; GracefulClose is listed because the public
# Sock.recv() raises it, just like UnrecoverableClose
__all__ = [
    "CallbackFnc",
    "SockControl",
    "SockType",
    "CURVE",
    "UnrecoverableClose",
    "GracefulClose",
    "JobCorruptedWarning",
    "Sock",
    "SockPool",
]
