import time

from openmodule.config import settings
from openmodule.utils.backend import Backend
from openmodule.models.backend import AccessRequest, CountMessage, Access, SessionStartMessage, SessionFinishMessage
from openmodule.core import core
from openmodule.models.base import Gateway
from openmodule_test.core import OpenModuleCoreTestMixin
from openmodule_test.eventlistener import MockEvent
from openmodule_test.utils import DeveloperError


class TestBackend(Backend):
    """
    Simple in-memory Backend implementation used in tests.

    * ``checked_in`` counts processed check-in / check-out messages
    * ``accessed`` records every medium id passed to ``check_access``
    * a gate named "error" makes the handlers raise ``Exception("test_exception")``
    * a gate named "empty" makes ``check_access`` return no accesses
    """
    __test__ = False  # otherwise pytest thinks this is a testcase

    def __init__(self, *args, **kwargs):
        # Per-instance state. A class-level list would be shared by every instance, so medium ids
        # recorded in one test would leak into the next. Set before super().__init__ so the
        # attributes already exist if the base class starts dispatching during construction.
        self.checked_in = 0
        self.accessed = []
        super().__init__(*args, **kwargs)
        self.message_processed = MockEvent()

    def wait_for_message_process(self):
        """
        waits for at least one message to be processed
        """
        self.message_processed.wait_for_call()

    def check_in(self, message: CountMessage):
        """Increment ``checked_in``; raise if the message's gate is "error"."""
        if message.gateway.gate == "error":
            self.message_processed()
            raise Exception("test_exception")
        self.checked_in += 1
        self.message_processed()

    def check_out(self, message: CountMessage):
        """Decrement ``checked_in`` (never below 0); raise if the message's gate is "error"."""
        if message.gateway.gate == "error":
            self.message_processed()
            raise Exception("test_exception")
        self.checked_in = self.checked_in - 1 if self.checked_in > 0 else 0
        self.message_processed()

    def check_in_session(self, message: SessionStartMessage):
        """Increment ``checked_in``; raise if ``entry_data["gate"]`` is "error"."""
        if message.entry_data.get("gate") == "error":
            self.message_processed()
            raise Exception("test_exception")
        self.checked_in += 1
        self.message_processed()

    def check_out_session(self, message: SessionFinishMessage):
        """Increment ``checked_in``; raise if ``exit_data["gate"]`` is "error"."""
        if message.exit_data.get("gate") == "error":
            self.message_processed()
            raise Exception("test_exception")
        # NOTE(review): this increments, unlike check_out which decrements. Kept as-is because
        # existing tests may rely on it; confirm whether a decrement was intended.
        self.checked_in += 1
        self.message_processed()

    def check_access(self, request: AccessRequest):
        """
        Record the requested medium id and return a week-long "permanent-employee" access.

        Raises for gate "error", returns an empty list for gate "empty".
        """
        self.accessed.append(request.medium_id)
        if request.gateway and request.gateway.gate == "error":
            raise Exception("test_exception")
        if request.gateway and request.gateway.gate == "empty":
            return []
        return [Access(category="permanent-employee", start=time.time() - 1000,
                       end=time.time() + 3600 * 24 * 7, user="user")]


class BackendTestMixin(OpenModuleCoreTestMixin):
    """
    BackendTestMixin with helper functions for testing backends
    * set the backend_class
    """
    backend_class = None

    @classmethod
    def setUpClass(cls) -> None:
        assert cls.backend_class, "set a backend_class"
        return super().setUpClass()

    def setUp(self):
        super().setUp()
        self.backend = self.backend_class(core())
        self.transaction_count = 0  # used to generate unique default transaction ids
        self.messages_by_transaction = {}  # transaction_id -> latest CountMessage created for it
        self.zone_count = {}  # zone -> count maintained by check_in (+1) / check_out (-1)

    def tearDown(self):
        self.backend.shutdown()
        super().tearDown()

    def check_auth(self, **kwargs):
        """
        standard check_auth call
        :param kwargs: Parameters to overwrite in the AccessRequest
        :return: result of the backend's rpc_check_access
        """
        kwargs.setdefault("name", settings.NAME)
        request = AccessRequest(**kwargs)
        return self.backend.rpc_check_access(request, None)

    def create_count_message(self, transaction_id=None, *, resource="arivotest", user="api-backend_user", gate="gate1",
                             direction="in", medium="lpr", medium_id="GARIVO1", zone="zone1",
                             category="permanent-employee", correction=False, error=None, **kwargs):
        """ creates standard count message and stores it under its transaction id
        :param transaction_id: Transaction ID for multiple sending of the same message, autogenerated if None
        :param resource: resource
        :param user: user
        :param gate: gate
        :param direction: direction
        :param medium: medium type (Medium)
        :param medium_id: id of medium
        :param zone: zone
        :param category: category (Category)
        :param correction: Flag if the message was only a correction
        :param error: Error if present, should not be used directly
        :param kwargs: Additional parameters for errors, should not be used directly
        :return: transaction_id
        """
        gateway = Gateway(gate=gate, direction=direction)
        if transaction_id is None:
            transaction_id = f"transaction-{self.transaction_count}"
            self.transaction_count += 1

        # explicit arguments win over anything passed through kwargs with the same key
        data = dict(kwargs)
        data.update(user=user, gateway=gateway, medium_type=medium, id=medium_id, category=category,
                    real=not correction, transaction_id=transaction_id, count=0, error=error,
                    zone=zone, name="count", type="count", timestamp=time.time(), resource=resource)
        message = CountMessage(**data)
        self.messages_by_transaction[transaction_id] = message
        return transaction_id

    def check_in(self, transaction_id=None, send=True, **kwargs) -> str:
        """ standard check_in call
        :param transaction_id: Transaction ID of created count message
        :param send: Pass the check in call to the backend
        :param kwargs: message kwargs if transaction_id is None
        :return: transaction_id
        """
        if transaction_id is None:
            transaction_id = self.create_count_message(**kwargs)
        message = self.messages_by_transaction[transaction_id]

        self.zone_count[message.zone] = 1 + self.zone_count.get(message.zone, 0)
        message.count = self.zone_count[message.zone]
        message.gateway.direction = "in"
        if send:
            self.backend.check_in(message)
        return transaction_id

    def check_in_double_entry(self, transaction_id) -> str:
        """ check_in call with double entry
        The new message reuses the previous entry's gate, medium type and medium id.
        :param transaction_id: previous entry transaction
        :return: transaction_id of the double entry message
        """
        old_message = self.messages_by_transaction[transaction_id]
        assert old_message.gateway.direction == "in", "needs previous check_in message"
        # "gateway" and "medium_type" would only end up in **kwargs and be overwritten by the
        # create_count_message defaults, so they are passed through the named parameters instead.
        data = {k: v for k, v in old_message.dict().items()
                if k not in ("transaction_id", "gateway", "medium_type")}
        data["gate"] = old_message.gateway.gate
        data["medium"] = old_message.medium_type
        data["medium_id"] = old_message.medium_id
        data["error"] = "double_entry"
        data["previous_transaction_id"] = [transaction_id]
        new_transaction_id = self.create_count_message(**data)
        self.check_in(new_transaction_id)
        return new_transaction_id

    def check_out(self, transaction_id) -> CountMessage:
        """ standard check_out call
        :param transaction_id: Transaction ID of the check_in message
        :return: the check_out message sent to the backend
        """
        message = self.messages_by_transaction[transaction_id]
        message.gateway.direction = "out"
        message.timestamp = time.time()
        new_message = CountMessage(**message.dict())
        self.messages_by_transaction[transaction_id] = new_message
        self.zone_count[new_message.zone] = self.zone_count.get(new_message.zone, 1) - 1
        self.backend.check_out(new_message)
        self.messages_by_transaction.pop(transaction_id)
        return new_message

    def check_out_double_exit(self, **kwargs) -> CountMessage:
        """ check_out call with double exit (no matching check_in, transaction_id is None)
        :param kwargs: Parameters for the count message
        """
        kwargs["error"] = "double_exit"
        kwargs["direction"] = "out"
        transaction_id = self.create_count_message(**kwargs)
        message = self.messages_by_transaction[transaction_id]
        message.transaction_id = None
        return self.check_out(transaction_id)

    def check_out_medium_changed(self, transaction_id, medium="nfc", medium_id="mynfcid") -> CountMessage:
        """ check_out call with changed medium
        :param transaction_id: Transaction ID of the check_in message
        :param medium: new medium
        :param medium_id: new medium_id
        :raises DeveloperError: if medium equals the current medium type
        """
        message = self.messages_by_transaction[transaction_id]
        if message.medium_type == medium:
            raise DeveloperError("medium needs to change")
        message.error = "medium_type_changed"
        message.previous_medium_id = message.medium_id
        message.previous_medium_type = message.medium_type
        message.medium_type = medium
        message.medium_id = medium_id
        return self.check_out(transaction_id)

    def check_out_medium_id_changed(self, transaction_id, medium_id="mynfcid") -> CountMessage:
        """ check_out call with changed medium id
        :param transaction_id: Transaction ID of the check_in message
        :param medium_id: new medium_id
        :raises DeveloperError: if medium_id equals the current medium id
        """
        message = self.messages_by_transaction[transaction_id]
        if message.medium_id == medium_id:
            raise DeveloperError("medium_id needs to change")
        message.error = "medium_id_changed"
        message.previous_medium_id = message.medium_id
        message.medium_id = medium_id
        return self.check_out(transaction_id)

    def check_out_user_changed(self, transaction_id, user="api-backend_user1") -> CountMessage:
        """ check_out call with changed user
        :param transaction_id: Transaction ID of the check_in message
        :param user: new user
        :raises DeveloperError: if user equals the current user
        """
        message = self.messages_by_transaction[transaction_id]
        # validate before mutating so a failed call leaves the stored message untouched
        if user == message.user:
            raise DeveloperError("user needs to change")
        message.previous_user = message.user
        message.user = user
        message.error = "user_changed"
        return self.check_out(transaction_id)

    def check_out_category_changed(self, transaction_id, category="filler-employee") -> CountMessage:
        """ check_out call with changed category
        :param transaction_id: Transaction ID of the check_in message
        :param category: new category
        :raises DeveloperError: if category equals the current category
        """
        message = self.messages_by_transaction[transaction_id]
        if message.category == category:
            raise DeveloperError("category needs to change")
        message.category = category
        message.error = "category_changed"
        return self.check_out(transaction_id)
