import base64
import datetime
import logging
import os
import socket
import sys
import threading
import time
from urllib import request

from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import x25519
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
from cryptography.hazmat.primitives.serialization import (Encoding,
                                                          NoEncryption,
                                                          PrivateFormat,
                                                          PublicFormat,
                                                          load_pem_public_key)

# State shared by the Server and Client threads (rebound at runtime).
aliases = dict()     # peer IP -> display name, set with /alias
connected = None     # host of the current peer; None while not connected
fernet = None        # Fernet cipher for the current session
private_key = None   # X25519 key; the Server generates a fresh one per session

# Chat commands, in the order /help lists them.
COMMANDS = [
    "/alias", "/clear", "/config", "/help", "/ip", "/privkey",
    "/pubkey", "/quit", "/remote", "/sendfile", "/time", "/uptime",
]

# Ports: the primary one is tried first, the ALT one if that fails.
LOCAL_PORT, LOCAL_ALT_PORT = 2048, 4096
REMOTE_PORT, REMOTE_ALT_PORT = 2048, 4096

# Reference point for /uptime.
START_TIME = datetime.datetime.now()

class Server(threading.Thread):
    """Background thread that serves one chat peer at a time.

    Each session regenerates the module-level X25519 ``private_key``, listens
    on LOCAL_PORT (falling back to LOCAL_ALT_PORT), accepts a single peer and
    prints its decrypted messages until that peer disconnects, then starts
    over with a fresh key and a fresh listening socket.
    """

    def accept_connection(self):
        """Accepts connection and derives a shared key.

        Receives the peer's PEM public key, answers with ours, and stores a
        Fernet cipher keyed by the HKDF-expanded X25519 shared secret in the
        module-level ``fernet``.

        Returns:
            True when the exchange succeeded.  False when it failed; the
            error is logged and ``fernet`` is reset to None so a key left
            over from an earlier session is never reused.
        """
        global fernet
        try:
            # exchange the X25519 public keys (the dialling peer sends first)
            peer_public_key = load_pem_public_key(self.peer.recv(4096))
            self.peer.sendall(private_key.public_key().public_bytes(
                Encoding.PEM, PublicFormat.SubjectPublicKeyInfo))
            shared_key = private_key.exchange(peer_public_key)
            derived_key = HKDFExpand(algorithm=hashes.SHA256(),
                                     length=32, info=None).derive(shared_key)
            fernet = Fernet(base64.urlsafe_b64encode(derived_key))
        except Exception as e:
            logging.error(str(e))
            fernet = None
            return False
        return True

    def _listen(self):
        """Binds ``self.incoming`` on LOCAL_PORT (or LOCAL_ALT_PORT) and listens."""
        # listen for ipv4 connections on all hosts
        self.incoming = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self.incoming.bind(("", LOCAL_PORT))
            logging.info(f"Listening on port {LOCAL_PORT}")
        except Exception as e:
            logging.error(str(e))
            logging.info("Trying alternate local")
            self.incoming.bind(("", LOCAL_ALT_PORT))
            logging.info(f"Listening on port {LOCAL_ALT_PORT}")
        self.incoming.listen(1)

    def _peer_name(self):
        """Returns the alias of the connected peer's IP, or the IP itself."""
        return aliases.get(self.address[0], self.address[0])

    def _receive_file(self):
        """Receives one file announced by a "FILE INCOMING" message.

        The next read carries the encrypted filename; everything after it, up
        to a 2 second pause or the peer closing the connection, is the
        encrypted file body.  The body is decrypted before the file is opened,
        so a corrupt transfer never truncates an existing file.

        Raises:
            ValueError: if the received filename is unusable.
            InvalidToken (from ``fernet.decrypt``) for a corrupt filename or
            body, and socket.timeout if the filename does not arrive in 2 s.
        """
        self.peer.settimeout(2)
        try:
            # keep only the final path component: the name comes from the
            # remote peer and must not be able to point outside this directory
            filename = os.path.basename(
                fernet.decrypt(self.peer.recv(4096)).decode())
            if filename in ("", ".", ".."):
                raise ValueError(f"Refusing to write file {filename!r}")
            logging.info(f"Receiving file {filename}")
            chunks = []
            while True:
                try:
                    chunk = self.peer.recv(4096)
                except socket.timeout:
                    break  # sender went quiet: assume the body is complete
                if not chunk:
                    break  # peer closed the connection mid-transfer
                chunks.append(chunk)
            contents = fernet.decrypt(b"".join(chunks))
            with open(filename, "wb") as file:
                file.write(contents)
            logging.info(f"Received file {filename}")
        finally:
            # restore blocking reads even when the transfer failed, otherwise
            # every later recv() in _serve_peer would time out after 2 seconds
            self.peer.settimeout(None)

    def _serve_peer(self):
        """Prints messages and receives files until the peer goes away.

        Returns when the peer closes its socket (``recv`` returns b"") or the
        socket itself fails.  Per-message errors, such as a token that does
        not decrypt, are logged and the session continues.
        """
        while True:
            try:
                data = self.peer.recv(4096)
            except OSError as e:
                # connection-level failure (e.g. reset): the session is over
                logging.error(str(e))
                return
            if not data:
                # an empty read means the peer closed its end of the socket
                return
            try:
                raw_message = fernet.decrypt(data).decode()
                if raw_message != "FILE INCOMING":
                    message = f"{time.strftime('%H:%M:%S')}|{self._peer_name()}: {raw_message}"
                    print(message)
                    logging.debug(message)
                else:
                    self._receive_file()
            except Exception as e:
                # InvalidToken stringifies to "", so fall back to the type name
                logging.error(str(e) or type(e).__name__)
                logging.info(f"Error from {self._peer_name()}")

    def _close_session(self):
        """Closes the peer and listening sockets and clears ``connected``."""
        global connected
        self.peer.close()
        self.incoming.close()
        connected = None

    def run(self):
        """Handles all of the incoming messages.

        Loops forever, serving one peer per iteration.  When the peer dialled
        us first (``connected`` is still unset), this side answers the key
        exchange and then dials back through the module-level ``client`` so
        both directions of the chat are open.
        """
        global private_key
        while True:
            # a new private key for every session
            logging.info("Generating private key")
            private_key = x25519.X25519PrivateKey.generate()

            self._listen()
            self.peer, self.address = self.incoming.accept()
            if not connected:
                if not self.accept_connection():
                    # without a shared key nothing can be decrypted; drop the
                    # peer instead of looping on undecryptable messages
                    self.peer.close()
                    self.incoming.close()
                    continue
                client.initate_connection(self.address[0], True)
                logging.info(f"New connection {self.address[0]}")
                logging.info("Press enter to continue")

            self._serve_peer()
            logging.info(f"{self._peer_name()} disconnected")
            self._close_session()


class Client(threading.Thread):
    """Foreground thread that reads user input, runs /commands and sends chat.

    Every method named in COMMANDS (without its leading "/") is a chat
    command: it receives the whitespace-split input line as ``args``, and its
    docstring is the text ``/help <command>`` prints.
    """

    def alias(self, args):
        """Aliases an IP to a name"""
        # usage: /alias <ip> <name>
        if len(args) < 3:
            logging.info("Usage: /alias <ip> <name>")
            return
        aliases[args[1]] = args[2]
        logging.info(f"Aliased {args[1]} to {args[2]}")

    def clear(self, args):
        """Clears the console"""
        if os.name == "posix":
            os.system("clear")
        else:
            os.system("cls")
        logging.info("Console cleared")

    def config(self, args):
        """Shows aliases and ports."""
        for alias in aliases:
            logging.info(f"Aliasing {alias} to {aliases[alias]}")
        logging.info(f"Local ports: {LOCAL_PORT}, {LOCAL_ALT_PORT}")
        logging.info(f"Remote ports: {REMOTE_PORT}, {REMOTE_ALT_PORT}")

    def help(self, args):
        """Shows all commands or info"""
        # accept both "/help alias" and "/help /alias"
        if len(args) > 1:
            command = "/" + args[1].lstrip("/")
            if command in COMMANDS:
                logging.info(getattr(self, command[1:]).__doc__)
                return
        logging.info(" ".join(COMMANDS))

    def ip(self, args):
        """Shows public IP address"""
        # the address is as seen by an external service, not a LAN address;
        # the timeout keeps an unreachable service from freezing the prompt
        with request.urlopen("http://ipv4.icanhazip.com", timeout=10) as response:
            logging.info(response.read().decode("utf8").strip())

    def privkey(self, args):
        """Shows the local private key"""
        logging.info("Do not disclose this key")
        logging.info("\n" + private_key.private_bytes(
            Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()).decode().strip())

    def pubkey(self, args):
        """Shows the local public key"""
        logging.info("\n" + private_key.public_key().public_bytes(
            Encoding.PEM, PublicFormat.SubjectPublicKeyInfo).decode().strip())

    def quit(self, args):
        """Quits the program"""
        self.outgoing.close()
        logging.info("Quit successfully")
        # SystemExit is not an Exception, so it escapes run() and ends the thread
        raise SystemExit

    def remote(self, args):
        """Shows connected IP address"""
        if connected:
            logging.info(f"Connected to {connected}")
        else:
            logging.info("Currently not connected")

    def sendfile(self, args):
        """Sends file over the network"""
        path = args[1]
        # read before announcing, so a missing/unreadable file sends nothing
        with open(path, "rb") as file:
            contents = file.read()
        logging.info(f"Sending file {path}")
        self.outgoing.sendall(fernet.encrypt(b"FILE INCOMING"))
        # only the base name: the receiver writes into its own directory
        self.outgoing.sendall(fernet.encrypt(os.path.basename(path).encode()))
        # pause so the filename and body arrive in separate reads
        time.sleep(1)
        # sendall takes care of the batching
        self.outgoing.sendall(fernet.encrypt(contents))
        logging.info(f"Sent file {args[1]} successfully")

    def time(self, args):
        """Shows current local time"""
        logging.info(time.strftime("%d %b %Y %H:%M:%S"))

    def uptime(self, args):
        """Shows time since program start."""
        logging.info(str(datetime.datetime.now() - START_TIME))

    def _dial(self, target_host, port):
        """Returns a new IPv4 socket connected to (target_host, port).

        The socket is closed before the error propagates, because a socket
        whose connect() failed cannot portably be reused.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((target_host, port))
        except BaseException:
            sock.close()
            raise
        return sock

    def initate_connection(self, target_host, no_exchange=False):
        """Tries the primary and alternate ports.

        On success ``self.outgoing`` is replaced by the new connection (the
        previous socket, if any, is closed) and ``connected`` is set to
        target_host.  Unless no_exchange is true (the Server already derived
        the key for this peer), public keys are exchanged and ``fernet`` is
        set; if that fails, ``connected`` is cleared and the socket closed.
        """
        global connected, fernet
        # establish an initial connection on a fresh socket per attempt
        try:
            new_socket = self._dial(target_host, REMOTE_PORT)
        except Exception as e:
            logging.error(str(e))
            logging.info("Trying alternate remote")
            try:
                new_socket = self._dial(target_host, REMOTE_ALT_PORT)
            except Exception as e:
                logging.error(str(e))
                return

        # publish the socket before `connected`, which run() waits on
        old_socket = getattr(self, "outgoing", None)
        self.outgoing = new_socket
        if old_socket is not None:
            old_socket.close()

        # setup connection from server thread
        connected = target_host
        if no_exchange:
            return

        # exchange the X25519 public keys (we send first)
        try:
            self.outgoing.sendall(private_key.public_key().public_bytes(
                Encoding.PEM, PublicFormat.SubjectPublicKeyInfo))
            peer_public_key = load_pem_public_key(self.outgoing.recv(4096))
            shared_key = private_key.exchange(peer_public_key)
            derived_key = HKDFExpand(algorithm=hashes.SHA256(),
                                     length=32, info=None).derive(shared_key)
            fernet = Fernet(base64.urlsafe_b64encode(derived_key))
        except Exception as e:
            logging.error(str(e))
            connected = None
            # drop the half-open connection; the next attempt dials afresh
            self.outgoing.close()

    def run(self):
        """Handles all of the outgoing messages.

        Prompts for a host until connected, then reads lines forever: lines
        starting with a known /command run that command, anything else is
        encrypted and sent.  A send failure clears ``connected`` and returns
        to the host prompt.
        """
        global connected
        logging.info("/help to list commands")
        while True:
            # Connect to a specified peer
            while not connected:
                target_host = input("HOST: ").strip()
                if target_host:
                    logging.info(f"Connecting to {target_host}")
                    self.initate_connection(target_host)
            logging.info(f"Connected to {connected}")

            try:
                # Either send message or run command
                while True:
                    message = input("")
                    if not message.strip():
                        # blank/whitespace-only lines have no command word
                        continue
                    # move the cursor up to overwrite the raw input line
                    print("\033[F", end="")
                    formatted_message = f"{time.strftime('%H:%M:%S')}|Local User: {message}"
                    print(formatted_message)
                    logging.debug(formatted_message)
                    check_command = message.split()
                    if check_command[0] in COMMANDS:
                        # dispatch "/name" to the method called "name"
                        try:
                            getattr(
                                self, check_command[0][1:])(check_command)
                        except Exception as e:
                            logging.error(str(e))
                    else:
                        self.outgoing.sendall(
                            fernet.encrypt(message.encode()))
            except Exception as e:
                logging.error(str(e))
                connected = None


if __name__ == "__main__":
    # setup message output and logging: INFO and above go to the console,
    # everything (including DEBUG chat transcripts) to the optional log file
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    handlers = [console_handler]
    if input("Log? (y/n): ").strip().lower() == "y":
        file_handler = logging.FileHandler(filename="snakewhisper.log")
        file_handler.setFormatter(logging.Formatter(
            "%(asctime)s - %(levelname)s: %(message)s"))
        handlers.append(file_handler)
    logging.basicConfig(level=logging.DEBUG,
                        format="%(levelname)s: %(message)s", handlers=handlers)

    # Create the client first: the server thread calls
    # client.initate_connection() as soon as a peer dials in, which would
    # raise NameError if that happened before `client` existed.
    client = Client()
    server = Server()
    server.daemon = True
    server.start()
    # give the server a moment to generate private_key and start listening
    time.sleep(1)
    client.start()
