#######################################################
# 
# ReceiveConnections.py
# Python implementation of the Class ReceiveConnections
# Generated by Enterprise Architect
# Created on:      19-May-2020 6:21:05 PM
# Original author: Natha Paquette
# 
#######################################################
import asyncio
import socket
import logging
import logging.handlers
import re
import ssl
from lxml import etree
import time
import os
from typing import Union

from FreeTAKServer.core.configuration.ClientReceptionLoggingConstants import ClientReceptionLoggingConstants
from FreeTAKServer.core.configuration.LoggingConstants import LoggingConstants
from FreeTAKServer.services.ssl_cot_service.model.raw_ssl_connection_information import RawSSLConnectionInformation as sat
from FreeTAKServer.core.configuration.CreateLoggerController import CreateLoggerController
from FreeTAKServer.core.configuration.ReceiveConnectionsConstants import ReceiveConnectionsConstants
from FreeTAKServer.services.ssl_cot_service.controllers.SSLSocketController import SSLSocketController

# Module logger, configured from a dedicated LoggingConstants instance that is
# only needed while the logger is being built.
logger = CreateLoggerController(
    "FTS_ReceiveConnections",
    logging_constants=LoggingConstants(log_name="FTS_ReceiveConnections"),
).getLogger()

# Message constants used by ReceiveConnections.listen when logging
# (RECEIVECONNECTIONSLISTENINFO / RECEIVECONNECTIONSLISTENERROR).
loggingConstants = ClientReceptionLoggingConstants()

# TODO: these should all be moved out to configuration
# Sentinel returned by receive_connection_data when the client sent the test payload.
TEST_SUCCESS = "success"
# Closing tag that marks the end of one CoT event on the wire.
END_OF_MESSAGE = b"</event>"
# NOTE(review): not referenced in the code visible here; confirm whether it is still used.
TIMEOUT_LENGTH = 60


# TODO: move health check values to constants and create controller for HealthCheck data

class ReceiveConnections:
    """Accepts TLS clients on a listening socket and reads their first CoT event."""

    # Matches each XML declaration individually (non-greedy): a greedy match
    # would swallow everything between the first "<?xml" and the last "?>",
    # deleting whole events when several arrive with their own declarations.
    _XML_DECLARATION = re.compile(r"<\?xml.*?\?>", re.DOTALL)

    def receive_connection_data(self, client) -> Union[etree.Element, str]:
        """Receive the first message a freshly connected client sends.

        Args:
            client (socket.socket): the (TLS-wrapped) client socket to read from.

        Raises:
            ConnectionError: if the client closed the connection before sending
                anything (subclass of Exception, so existing callers still catch it).

        Returns:
            Union[etree.Element, str]: for a real client, a ``<multiEvent>`` element
            wrapping every received ``<event>``; if the received text equals
            ``ReceiveConnectionsConstants().TESTDATA``, the ``TEST_SUCCESS`` sentinel.
        """
        constants = ReceiveConnectionsConstants()
        client.settimeout(int(constants.RECEIVECONNECTIONDATATIMEOUT))
        # Block for the first byte so an idle/closed client is detected quickly.
        first_byte = client.recv(1)
        if first_byte == b"":
            raise ConnectionError('empty data')
        # Decode the whole payload at once; decoding the first byte on its own
        # fails whenever it is the lead byte of a multi-byte sequence (e.g. a BOM).
        xmlstring = (first_byte + self.recv_until(client, END_OF_MESSAGE)).decode()
        if xmlstring == constants.TESTDATA:
            return TEST_SUCCESS
        # recv_until lowers the timeout; restore the connection-data timeout.
        client.settimeout(int(constants.RECEIVECONNECTIONDATATIMEOUT))
        # Wrap in a single root so several concatenated events parse as one document.
        xmlstring = "<multiEvent>" + xmlstring + "</multiEvent>"
        # lxml rejects str input that carries an encoding declaration, so strip them.
        xmlstring = self._XML_DECLARATION.sub('', xmlstring)
        return etree.fromstring(xmlstring)

    def listen(self, sock):
        """Accept one TLS client on ``sock`` and read its first CoT event.

        Args:
            sock (socket.socket): the bound server socket to accept on.

        Returns:
            RawSSLConnectionInformation describing the accepted client, or ``-1``
            when no usable client was obtained: accept/handshake timeout or error,
            unreadable first message, invalid connection data, or a health-check
            probe that has already been answered with ``b'success'``.
        """
        constants = ReceiveConnectionsConstants()
        sock.listen(constants.LISTEN_COUNT)
        # Bound up front so every cleanup path can reference them even when
        # accept() or the TLS wrap fails before they are assigned.
        client = None
        ssl_client = None
        try:
            # NOTE: setdefaulttimeout is process-wide; it applies to every socket
            # created afterwards, not just this listener.
            socket.setdefaulttimeout(constants.SSL_SOCK_TIMEOUT)
            sock.settimeout(constants.SSL_SOCK_TIMEOUT)
            try:
                client, address = sock.accept()
                ssl_client = SSLSocketController().wrap_client_socket(client)
            except ssl.SSLError as ex:
                self.disconnect_socket(client, ssl_client)
                logger.warning('ssl error thrown in connection attempt ' + str(ex))
                return -1
            # socket.timeout (raised by accept/handshake) is only an alias of
            # asyncio.TimeoutError from Python 3.11 on, so catch both.
            except (socket.timeout, asyncio.TimeoutError) as ex:
                self.disconnect_socket(client, ssl_client)
                logger.warning('timeout error thrown in connection attempt ' + str(ex))
                return -1

            logger.info('client connected over ssl %s %s', address, time.time())
            try:
                events = self.receive_connection_data(client=ssl_client)
            except Exception:
                # The first read is retried once before the client is dropped.
                try:
                    events = self.receive_connection_data(client=ssl_client)
                except Exception as exb:
                    self.disconnect_socket(client, ssl_client)
                    logger.warning("receiving connection data from client failed with exception " + str(exb))
                    return -1

            # TODO: move out to separate function
            if events == TEST_SUCCESS:
                # Health-check probe: answer it, then drop it; it is not a real client.
                ssl_client.send(b'success')
                self.disconnect_socket(client, ssl_client)
                return -1

            ssl_client.settimeout(0)  # set the socket to non blocking
            logger.info(loggingConstants.RECEIVECONNECTIONSLISTENINFO)
            try:
                # establish the object containing important information about the client
                raw_connection_information = self.instantiate_client_object(address, client, ssl_client, events)
                logger.info("client accepted")
                if raw_connection_information.socket is not None and raw_connection_information.xmlString:
                    return raw_connection_information
                logger.warning("final socket entry is invalid")
                self.disconnect_socket(client, ssl_client)
                return -1
            except Exception as ex:
                self.disconnect_socket(client, ssl_client)
                logger.warning('exception in returning data ' + str(ex))
                return -1

        except Exception as ex:
            logger.warning('%s %s', loggingConstants.RECEIVECONNECTIONSLISTENERROR, ex)
            try:
                self.disconnect_socket(client, ssl_client)
            except Exception as cleanup_error:
                logger.debug('cleanup after listen error failed: %s', cleanup_error)
            return -1

    @staticmethod
    def _close_quietly(operation, *args):
        """Run one socket teardown step, logging (not raising) socket errors."""
        try:
            operation(*args)
        except OSError as ex:
            logger.debug('ignoring error while closing client socket: %s', ex)

    def disconnect_socket(self, client, ssl_client):
        """Best-effort teardown of a client connection.

        Each step runs independently, so a failing ``shutdown`` (e.g. the peer is
        already gone) no longer prevents the sockets from being closed.

        Args:
            client (socket.socket | None): the raw accepted socket, or None if
                ``accept()`` never succeeded.
            ssl_client (socket.socket | None): the TLS-wrapped socket, or None if
                wrapping never succeeded.
        """
        if ssl_client is not None:
            self._close_quietly(ssl_client.shutdown, socket.SHUT_RDWR)
            self._close_quietly(ssl_client.close)
        if client is not None:
            self._close_quietly(client.close)

    def instantiate_client_object(self, address, unwrapped_client, client, events):
        """Build the connection record handed back by ``listen``.

        Args:
            address: the peer address from ``accept()``; its first item is stored as the IP.
            unwrapped_client (socket.socket): the raw accepted socket.
            client (socket.socket): the TLS-wrapped socket used for traffic.
            events (etree.Element): the ``<multiEvent>`` root from ``receive_connection_data``.

        Raises:
            IndexError: if ``events`` contains no ``<event>`` child.

        Returns:
            RawSSLConnectionInformation: with ``xmlString`` set to the first
            ``<event>`` serialized and decoded to ``str``.
        """
        first_event = events.find('event')
        if first_event is None:
            raise IndexError("no <event> element received from client")
        raw_connection_information = sat()
        raw_connection_information.ip = address[0]
        raw_connection_information.socket = client
        raw_connection_information.unwrapped_sock = unwrapped_client
        raw_connection_information.xmlString = etree.tostring(first_event).decode('utf-8')
        return raw_connection_information

    def recv_until(self, client, delimiter) -> bytes:
        """Receive data until ``delimiter`` appears or the receive window expires.

        Args:
            client (socket.socket): client socket; its timeout is set to 4 seconds.
            delimiter (bytes): bytestring that ends the message.

        Returns:
            bytes: everything received. This may be ``b""``, and it may lack the
            delimiter if the deadline passed, the peer closed the connection, or
            a socket/TLS error (including a per-recv timeout) occurred.
        """
        constants = ReceiveConnectionsConstants()
        window = constants.RECEIVECONNECTIONDATATIMEOUT
        buffer_size = constants.CONNECTION_DATA_BUFFER
        message = b""
        start_receive_time = time.time()
        client.settimeout(4)
        while delimiter not in message and time.time() - start_receive_time <= window:
            try:
                chunk = client.recv(buffer_size)
            except OSError:
                # socket.timeout and ssl errors are OSError subclasses: hand back
                # what arrived so far and let the caller decide.
                break
            if not chunk:
                # Peer closed: recv would keep returning b"" immediately, so stop
                # instead of spinning until the deadline.
                break
            message += chunk
        return message
