from typing import Optional, Dict, Any, List
import logging
from concurrent.futures import ThreadPoolExecutor
from rpc_gateway import errors, messages, gateway

logger = logging.getLogger(__name__)


class Server(gateway.GatewayClient):
    """Gateway client that exposes local Python objects for remote get/set/call.

    Objects are registered under string names. ``_on_start`` sends the names of
    all registered objects to the gateway with a REGISTER request (presumably
    invoked by the base client when it starts — confirm in ``GatewayClient``).
    Incoming 'get', 'set' and 'call' requests are dispatched to the named object
    via ``run_in_executor`` on a thread pool, so attribute access and method calls
    do not run on the event loop thread.
    """

    def __init__(self, gateway_url: str = 'ws://localhost:8888', max_workers: int = 5, instances: Optional[Dict[str, Any]] = None):
        """Create the server.

        :param gateway_url: URL passed to the base gateway client.
        :param max_workers: size of the thread pool that runs get/set/call handlers.
        :param instances: optional initial name -> object mapping. It is copied, so
            later ``register_instance`` calls never mutate a dict owned by the caller.
        """
        super().__init__(gateway_url)
        # Defensive copy: the caller's dict must not be aliased and mutated by us.
        self.instances: Dict[str, Any] = {} if instances is None else dict(instances)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    async def _on_start(self):
        """Announce all currently registered instance names to the gateway."""
        await self._register_gateway_instances(list(self.instances.keys()))

    async def _register_gateway_instances(self, instances: List[str]):
        """Send a REGISTER request to the gateway carrying the given instance names."""
        await self.message_pump.send_request(messages.Method.REGISTER, instances)

    def register_instance(self, name: str, instance: Any) -> Any:
        """Expose ``instance`` to remote clients under ``name``.

        :returns: ``instance`` unchanged, so the call can be used inline.
        :raises errors.InstanceAlreadyRegisteredError: if ``name`` is already taken.
        """
        # NOTE(review): gateway registration only happens in _on_start; an instance
        # added after the client has started is not announced here — confirm whether
        # callers need to re-register with the gateway themselves.
        if name in self.instances:
            raise errors.InstanceAlreadyRegisteredError(f'Instance already registered with name: {name}')

        self.instances[name] = instance

        return instance

    def get_instance(self, instance_name: str) -> Any:
        """Return the object registered under ``instance_name``.

        :raises errors.InstanceNotFoundError: if no object has that name.
        """
        try:
            return self.instances[instance_name]
        except KeyError:
            # The internal KeyError adds nothing beyond the domain error; drop it.
            raise errors.InstanceNotFoundError(f'Instance not found: {instance_name}') from None

    def get(self, instance_name: str, attribute_name: str) -> messages.Response:
        """Read an attribute of a registered instance.

        Returns a METHOD-status response (no data) when the attribute is callable,
        a default response carrying the value otherwise, and an ERROR response
        carrying the exception if lookup fails.
        """
        # Broad catch is deliberate: this is the RPC boundary and any failure is
        # reported back to the remote caller rather than raised here.
        try:
            instance = self.get_instance(instance_name)
            data = getattr(instance, attribute_name)

            if callable(data):
                return messages.Response(status=messages.Status.METHOD)

            return messages.Response(data=data)
        except Exception as err:
            return messages.Response(status=messages.Status.ERROR, data=err)

    def set(self, instance_name: str, attribute_name: str, value: Any) -> messages.Response:
        """Assign ``value`` to an attribute of a registered instance.

        Returns an empty default response on success, or an ERROR response
        carrying the exception on failure.
        """
        try:
            instance = self.get_instance(instance_name)
            setattr(instance, attribute_name, value)
            return messages.Response()
        except Exception as err:
            return messages.Response(status=messages.Status.ERROR, data=err)

    def call(self, instance_name: str, attribute_name: str, args, kwargs) -> messages.Response:
        """Invoke ``instance.attribute_name(*args, **kwargs)`` on a registered instance.

        Returns a response carrying the call's return value, or an ERROR response
        carrying the exception raised by lookup or by the call itself.
        """
        try:
            instance = self.get_instance(instance_name)
            method = getattr(instance, attribute_name)
            data = method(*args, **kwargs)
            return messages.Response(data=data)
        except Exception as err:
            return messages.Response(status=messages.Status.ERROR, data=err)

    async def _run(self, *args):
        """Run ``args[0](*args[1:])`` on the server's thread pool and await its result."""
        return await self.event_loop.run_in_executor(self.executor, *args)

    async def _handle_request(self, request: messages.Request) -> messages.Response:
        """Dispatch a gateway request to get/set/call on the worker thread pool.

        Supported ``request.method`` values are 'get', 'set' and 'call'. Any other
        method yields an ERROR response carrying ``errors.InvalidMethodError``.
        A payload missing a required key (or that is not subscriptable) also yields
        an ERROR response instead of raising out of the handler. This matches how
        the get/set/call handlers report their own failures.
        """
        method = request.method
        # Only payload extraction lives in the try; handler execution is separate.
        try:
            if method == 'get':
                handler = self.get
                handler_args = (request.data['instance'], request.data['attribute'])
            elif method == 'set':
                handler = self.set
                handler_args = (request.data['instance'], request.data['attribute'], request.data['value'])
            elif method == 'call':
                handler = self.call
                handler_args = (request.data['instance'], request.data['attribute'], request.data['args'], request.data['kwargs'])
            else:
                return messages.Response(status=messages.Status.ERROR, data=errors.InvalidMethodError(f'Invalid method: {method}'))
        except (KeyError, TypeError) as err:
            return messages.Response(status=messages.Status.ERROR, data=err)

        # Handlers touch arbitrary user objects and may block, so keep them off the event loop.
        return await self._run(handler, *handler_args)


if __name__ == '__main__':
    # Manual demo: expose one sample object under the name 'test' and run the
    # server against the default gateway URL with verbose logging.

    class TestClass:
        """Sample object exposed over RPC for manual testing."""

        foo = 'bar'

        def method(self):
            """Return a fixed marker string so a remote 'call' can be checked."""
            return 'baz'

    logging.basicConfig(level=logging.DEBUG)

    demo_server = Server()
    demo_server.register_instance('test', TestClass())
    demo_server.start()