# coding=utf-8
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import os
import socket
import struct
import sys
import threading
import time
import traceback

from apm_client.compat import queue
from apm_client.core.agent.commands import Register
from apm_client.core.agent.manager import get_socket_path, get_local_socket_path
from apm_client.core.config import scout_config
from apm_client.core.threading import SingletonThread

# Time unit - monkey-patched in tests to make them run faster
SECOND = 1

logger = logging.getLogger(__name__)


class CoreAgentSocketThread(SingletonThread):
    """Background thread that forwards queued commands to core agent sockets.

    Commands are enqueued through ``send()``; ``run()`` takes them off the
    queue and writes them to the foreign socket (config value ``monitor``)
    and/or the local socket (config value ``local_monitor``).
    """
    # NOTE(review): not referenced elsewhere in this file; presumably used by
    # SingletonThread to guard instance creation - confirm.
    _instance_lock = threading.Lock()
    # Checked by run() after each queue read to decide when to exit the loop.
    _stop_event = threading.Event()
    # Pending commands, bounded at 500 entries; send() logs and drops a
    # command when the queue is full.
    _command_queue = queue.Queue(maxsize=500)

    @classmethod
    def _on_stop(cls):
        """Run the parent stop hook, then wake the thread's queue read.

        A ``None`` sentinel is enqueued without blocking so that a pending
        ``_command_queue.get()`` in ``run()`` returns right away.
        """
        super(CoreAgentSocketThread, cls)._on_stop()
        try:
            cls._command_queue.put_nowait(None)
        except queue.Full:
            # A full queue already has items for get() to return, so no
            # sentinel is needed to unblock it.
            pass

    @classmethod
    def send(cls, command):
        """Enqueue ``command`` without blocking and ensure the thread runs.

        When the queue is full the command is not queued; the failure is
        only logged at debug level.
        """
        try:
            cls._command_queue.put_nowait(command)
        except queue.Full as exc:
            # TODO mark the command as not queued?
            logger.debug("CoreAgentSocketThread error on send: %r", exc, exc_info=exc)

        cls.ensure_started()

    @classmethod
    def wait_until_drained(cls, timeout_seconds=2.0, callback=None):
        """Poll the command queue until it is empty or the timeout elapses.

        Args:
            timeout_seconds: Maximum time to keep polling.
            callback: Optional callable invoked at most once, with the
                current queue size, the first time the queue is found
                non-empty before the timeout.

        Returns:
            bool: True if the queue was empty at the last check, else False.
        """
        poll_interval = min(timeout_seconds, 0.05)
        started_at = time.time()
        notify = callback
        while True:
            pending = cls._command_queue.qsize()
            drained = pending == 0
            waited = time.time() - started_at
            if drained or waited >= timeout_seconds:
                return drained

            if notify is not None:
                notify(pending)
                # Only report the backlog once.
                notify = None

            # Make sure something is actually consuming the queue.
            cls.ensure_started()

            time.sleep(poll_interval)

    def run(self):
        """Thread body: connect, register, then forward queued commands.

        Both sockets are created up front. Inside the loop a command is
        fetched with a timeout of ``1 * SECOND``; a failed ``_send`` triggers
        a disconnect / connect / register cycle. The loop ends once the stop
        event is set, and any exception is logged rather than propagated.
        Both sockets are closed on the way out.
        """
        self.local_socket_path = get_local_socket_path()
        self.socket_path = get_socket_path()
        self.local_socket = self.make_local_socket()
        self.socket = self.make_socket()

        try:
            self._connect()
            self._register()
            stop_requested = False
            while not stop_requested:
                try:
                    command = self._command_queue.get(True, 1 * SECOND)
                except queue.Empty:
                    command = None

                if command is not None:
                    if self._send(command):
                        self._command_queue.task_done()
                    else:
                        # Something was wrong with the socket: rebuild the
                        # connections and register again.
                        self._disconnect()
                        self._connect()
                        self._register()

                # The stop event is checked after each read so that a caller
                # can open, send, and then stop immediately (as done for the
                # metadata event at application start time).
                stop_requested = self._stop_event.is_set()
            logger.debug("CoreAgentSocketThread stopping.")
        except Exception as exc:
            logger.debug("CoreAgentSocketThread exception: %r", exc, exc_info=exc)
        finally:
            closing = (
                (self.local_socket, "Local CoreAgentSocketThread stopped."),
                (self.socket, "Foreign CoreAgentSocketThread stopped."),
            )
            for sock, stopped_message in closing:
                sock.close()
                logger.debug(stopped_message)

    def _send(self, command):
        """Serialize ``command`` and write it to each enabled core agent socket.

        The result of ``command.message()`` is JSON-encoded and framed with a
        4-byte big-endian length prefix. The frame is written to the foreign
        socket when the ``monitor`` config value is truthy (and, if that
        write succeeded, the reply is read and discarded) and to the local
        socket when ``local_monitor`` is truthy.

        Returns:
            bool: True when the frame was written to at least one socket;
            False when serialization failed or no write succeeded.
        """
        msg = command.message()

        try:
            data = json.dumps(msg)
        except (ValueError, TypeError) as exc:
            logger.debug(
                "Exception when serializing command message: %r", exc, exc_info=exc
            )
            return False

        # The prefix announces the number of bytes on the wire, so measure
        # the payload after encoding rather than the text length.
        payload = data.encode("utf-8")
        full_data = struct.pack(">I", len(payload)) + payload

        def write_frame(sock):
            # Catch socket.error, as _connect and _disconnect do: it is an
            # alias of OSError on Python 3 but a separate class on Python 2.
            try:
                sock.sendall(full_data)
            except socket.error as exc:
                logger.debug(
                    (
                        "CoreAgentSocketThread exception on _send:"
                        + " %r on PID: %s on thread: %s"
                    ),
                    exc,
                    os.getpid(),
                    threading.current_thread(),
                    exc_info=exc,
                )
                return False
            return True

        foreign_core = False
        if scout_config.value("monitor"):
            foreign_core = write_frame(self.socket)
            if foreign_core:
                # Only wait for a reply when the command actually went out.
                # TODO do something with the response sent back in reply to command
                self._read_response()

        local_core = False
        if scout_config.value("local_monitor"):
            local_core = write_frame(self.local_socket)

        return local_core or foreign_core

    def _read_response(self):
        """Read one length-prefixed response from the foreign socket.

        The response is a 4-byte big-endian size followed by that many bytes
        of payload.

        Returns:
            bytearray or None: The payload, or None when the size header is
            incomplete, the peer closes the connection before the whole
            payload arrives, or a socket error occurs.
        """
        try:
            raw_size = self.socket.recv(4)
            if len(raw_size) != 4:
                # Ignore invalid responses
                return None
            size = struct.unpack(">I", raw_size)[0]
            message = bytearray(0)

            while len(message) < size:
                # Ask only for what is still missing so the next response's
                # bytes are never consumed.
                chunk = self.socket.recv(size - len(message))
                if not chunk:
                    # An empty read means the peer closed the connection;
                    # looping again would spin forever on empty reads.
                    return None
                message += chunk

            return message
        except socket.error as exc:
            # socket.error is OSError on Python 3 and also covers timeouts.
            logger.debug(
                "CoreAgentSocketThread error on read response: %r", exc, exc_info=exc
            )
            return None

    def _register(self):
        """Send a ``Register`` command built from the current configuration.

        The app name, key and hostname are read from ``scout_config``.
        """
        registration = Register(
            app=scout_config.value("name"),
            key=scout_config.value("key"),
            hostname=scout_config.value("hostname"),
        )
        self._send(registration)

    def _connect(self, connect_attempts=5, retry_wait_secs=1):
        """Connect the enabled core agent sockets, retrying on failure.

        The local socket is tried when the ``local_monitor`` config value is
        truthy and the foreign socket when ``monitor`` is truthy. Each socket
        gets up to ``connect_attempts`` tries, sleeping
        ``retry_wait_secs * SECOND`` between failed tries, and stops at its
        first success.

        Args:
            connect_attempts: Maximum number of connection tries per socket.
            retry_wait_secs: Delay between failed tries, in units of SECOND.

        Raises:
            socket.error: When no socket ended up connected. The most recent
                connection error is re-raised if there was one.
        """

        def connect_with_retries(sock, get_address, path):
            # Returns (connected, last_error) for one socket.
            last_error = None
            for attempt in range(1, connect_attempts + 1):
                logger.debug(
                    (
                        "CoreAgentSocketThread attempt %d, connecting to %s, "
                        + "PID: %s, Thread: %s"
                    ),
                    attempt,
                    path,
                    os.getpid(),
                    threading.current_thread(),
                )
                try:
                    sock.connect(get_address())
                    sock.settimeout(3 * SECOND)
                except socket.error as exc:
                    last_error = exc
                    logger.debug(
                        "CoreAgentSocketThread connection error: %r", exc, exc_info=exc
                    )
                    # Do not wait after the final attempt.
                    if attempt < connect_attempts:
                        time.sleep(retry_wait_secs * SECOND)
                else:
                    logger.debug("CoreAgentSocketThread connected")
                    # Stop here: connecting an already connected socket again
                    # would only fail and waste the remaining attempts.
                    return True, None
            return False, last_error

        local_core_connect = False
        foreign_core_connect = False
        last_error = None

        if scout_config.value("local_monitor"):
            local_core_connect, error = connect_with_retries(
                self.local_socket,
                self.get_local_socket_address,
                self.local_socket_path,
            )
            if error is not None:
                last_error = error

        if scout_config.value("monitor"):
            foreign_core_connect, error = connect_with_retries(
                self.socket,
                self.get_socket_address,
                self.socket_path,
            )
            if error is not None:
                last_error = error

        if not local_core_connect and not foreign_core_connect:
            # A bare ``raise`` has no active exception at this point, so
            # raise an explicit error instead.
            if last_error is not None:
                raise last_error
            raise socket.error(
                "CoreAgentSocketThread could not connect to any core agent"
            )

    def _disconnect(self):
        """Close both sockets and replace each with a fresh, unconnected one.

        A ``socket.error`` raised while closing is logged and ignored; the
        replacement socket is created either way.
        """
        leaving_message = "CoreAgentSocketThread disconnecting from %s"
        close_failed_message = "CoreAgentSocketThread exception on disconnect: %r"

        # Local socket first.
        logger.debug(leaving_message, self.local_socket_path)
        try:
            self.local_socket.close()
        except socket.error as err:
            logger.debug(close_failed_message, err, exc_info=err)
        finally:
            self.local_socket = self.make_local_socket()

        # Then the foreign socket.
        logger.debug(leaving_message, self.socket_path)
        try:
            self.socket.close()
        except socket.error as err:
            logger.debug(close_failed_message, err, exc_info=err)
        finally:
            self.socket = self.make_socket()

    def make_local_socket(self):
        """Create an unconnected stream socket for ``local_socket_path``.

        The address family is ``AF_INET`` when the path's ``is_tcp`` flag is
        set and ``AF_UNIX`` otherwise.
        """
        use_tcp = self.local_socket_path.is_tcp
        family = socket.AF_INET if use_tcp else socket.AF_UNIX
        return socket.socket(family, socket.SOCK_STREAM)

    def make_socket(self):
        """Create an unconnected stream socket for ``socket_path``.

        The address family is ``AF_INET`` when the path's ``is_tcp`` flag is
        set and ``AF_UNIX`` otherwise.
        """
        use_tcp = self.socket_path.is_tcp
        family = socket.AF_INET if use_tcp else socket.AF_UNIX
        return socket.socket(family, socket.SOCK_STREAM)

    def get_local_socket_address(self):
        """Return the address to pass to ``connect()`` for the local socket.

        For a TCP path this is a ``(host, port)`` tuple parsed from
        ``tcp_address``; otherwise the path object itself is returned.
        """
        path = self.local_socket_path
        if not path.is_tcp:
            return path

        hostname, _sep, port_text = path.tcp_address.partition(":")
        if sys.version_info[0] == 2:
            # On Python 2 the host is converted to a byte string.
            hostname = bytes(hostname)
        return hostname, int(port_text)

    def get_socket_address(self):
        """Return the address to pass to ``connect()`` for the foreign socket.

        For a TCP path this is a ``(host, port)`` tuple parsed from
        ``tcp_address``; otherwise the path object itself is returned.
        """
        path = self.socket_path
        if not path.is_tcp:
            return path

        hostname, _sep, port_text = path.tcp_address.partition(":")
        if sys.version_info[0] == 2:
            # On Python 2 the host is converted to a byte string.
            hostname = bytes(hostname)
        return hostname, int(port_text)
