import concurrent.futures as cf
import logging
from abc import ABC, abstractmethod

from ..core.utils import protocol_version_conversion
from ..network.Endpoint import Endpoint
from ..sockets.SocketAddress import SocketAddress

log = logging.getLogger(__name__)


class VulnerabilityTest(ABC):
    """
    Base class for SSL/TLS vulnerability tests

    Subclasses implement :meth:`test` and fill ``valid_protocols`` with the
    protocol names the test applies to. ``scan_once`` selects whether the test
    runs on a single protocol (True) or on every usable protocol (False).
    """

    # Human readable test name, used in log messages
    name = ""
    # Abbreviated test name
    short_name = ""
    # Description of the tested vulnerability
    description = ""

    def __init__(self, supported_protocols, address, protocol):
        """
        Constructor

        :param list supported_protocols: Webservers supported SSL/TLS protocols
        :param SocketAddress address: Webserver address
        :param str protocol: SSL/TLS protocol
        """
        self.address = address
        self.protocol = protocol
        self.supported_protocols = supported_protocols
        # Protocols this test applies to; subclasses are expected to fill it
        self.valid_protocols = []
        # True -> scan_for_protocol (one protocol), False -> scan_for_protocols
        self.scan_once = True
        self.usage = "vulnerability_scan"

    def scan(self):
        """
        Call the appropriate scan methods

        Only protocols that are both supported by the webserver and valid for
        this test are considered.

        :return: ``(result, comment, ...)`` tuple; a bare ``False`` when no
            usable protocol exists
        :rtype: tuple or bool
        """
        log.info(f"Testing for {self.name}")
        usable_protocols = [
            p for p in self.supported_protocols if p in self.valid_protocols
        ]
        if not usable_protocols:
            log.error(f"No usable protocols found for {self.name}")
            # NOTE(review): callers receive a bare False here but a tuple on
            # every other path; kept as-is for backward compatibility.
            return False
        if self.scan_once:
            result = self.scan_for_protocol(usable_protocols)
        else:
            result = self.scan_for_protocols(usable_protocols)
        log.info(f"{self.name} test done")
        if isinstance(result, tuple):
            return result
        return result, ""

    def scan_for_protocol(self, usable_protocols):
        """
        Test the webserver for a vulnerability on a single protocol version

        Uses ``self.protocol`` when it is valid for this test; otherwise one is
        picked from ``usable_protocols`` via ``Endpoint.worst_or_best_protocol``.

        :param list usable_protocols: Protocols supported by the server and
            valid for this test
        :return: Whether the test was positive
        :rtype: bool or tuple
        """
        protocol = self.protocol
        # NOTE(review): only validity is checked here, not whether the server
        # supports self.protocol; confirm callers always pass a supported one.
        if protocol not in self.valid_protocols:
            log.warning(
                f"{self.protocol} not valid for {self.name}, finding valid protocol"
            )
            protocol = Endpoint.worst_or_best_protocol(usable_protocols, False)
        return self.test(protocol_version_conversion(protocol))

    def scan_for_protocols(self, usable_protocols):
        """
        Test the webserver for a vulnerability on all usable protocol versions

        :param list usable_protocols: Protocols supported by the server and
            valid for this test
        :return: Whether any test was positive
        :rtype: bool or tuple
        """
        self.run_once()
        # Filter a local copy instead of mutating self.supported_protocols:
        # usable_protocols was already derived from it, so mutating it would
        # not change this scan, and the list may be shared with other tests.
        protocols = self._without_redundant_tls11(usable_protocols)
        protocol_versions_numbers = [
            protocol_version_conversion(p) for p in protocols
        ]
        if len(protocol_versions_numbers) == 1:
            return self.test(protocol_versions_numbers[0])
        results = self.multithreading_tests(self.test, protocol_versions_numbers)
        return self.parse_result_tuple(results)

    @staticmethod
    def _without_redundant_tls11(protocols):
        """
        Drop TLSv1.1 from a protocol list when TLSv1.0 is also in it

        NOTE(review): presumably TLSv1.0 results are considered representative
        for TLSv1.1 here — confirm.

        :param list protocols: Protocol names
        :return: New list; the input list is left untouched
        :rtype: list
        """
        if "TLSv1.1" in protocols and "TLSv1.0" in protocols:
            return [p for p in protocols if p != "TLSv1.1"]
        return list(protocols)

    @staticmethod
    def parse_result_tuple(results):
        """
        Parse vulnerability test result

        Implemented test functions can either return just a bool value
        or a tuple containing a comment together with the test result.

        :param list results: Test results
        :return: ``any(results)`` when no result is a tuple, otherwise
            ``(True, comment, ...)`` with one comment per tuple result
        :rtype: bool or tuple
        :raises ValueError: If a tuple result has a falsy test result
        """
        result_tuples = [r for r in results if isinstance(r, tuple)]
        if not result_tuples:
            return any(results)
        comments = []
        for result in result_tuples:
            if not result[0]:
                raise ValueError(
                    "False test results can't have a comment attached to them"
                )
            comments.append(result[1])
        # Every tuple result is positive here, so the combined result is True
        return (True, *comments)

    @staticmethod
    def multithreading_tests(function, protocol_binary_versions):
        """
        Run tests in parallel, one thread per protocol version

        :param function: Function to be run, called once per version
        :param list protocol_binary_versions: Protocol versions to test
        :return: Results of the scans, in the same order as the input versions
        :rtype: list
        """
        if not protocol_binary_versions:
            # ThreadPoolExecutor rejects max_workers=0
            return []
        log.info(
            f"Creating {len(protocol_binary_versions)} threads for vulnerability tests"
        )
        with cf.ThreadPoolExecutor(
            max_workers=len(protocol_binary_versions)
        ) as executor:
            # executor.map yields results in submission order and re-raises
            # any exception raised by `function`
            return list(executor.map(function, protocol_binary_versions))

    @abstractmethod
    def test(self, version):
        """
        Test the implemented vulnerability

        :param int version: SSL/TLS version to test on
        :return: Whether the test passed or not
        :rtype: bool or tuple
        """
        pass

    def run_once(self):
        """
        Hook run once before multi-protocol testing; no-op by default

        Vulnerability tests override this if some action is needed before the
        testing, e.g. Drown SSLv2 scanning.
        """
        pass
